require 'dbi'

module TapKit

	# Interactive terminal prompt for DBI connection parameters.
	# #run returns a Hash with the keys 'url', 'user' and 'password'.
	class DBILoginPrompt < LoginPrompt
		def run
			conn = {}
			puts "Login database with DBI"

			conn['url']      = _ask("URL: ")
			conn['user']     = _ask("Username: ")
			# NOTE(review): the password is echoed to the terminal as typed.
			conn['password'] = _ask("Password: ")

			conn
		end

		private

		# Prints +label+ and reads one line from standard input.
		# Reads $stdin explicitly: Kernel#gets reads ARGF, i.e. the files named
		# in ARGV when there are any, not the user's terminal. Returns "" at end
		# of input instead of failing on nil.
		def _ask( label )
			print label
			$stdout.flush
			line = $stdin.gets
			line ? line.chomp : ''
		end
	end

	# Generic adapter for Ruby/DBI connections. It supplies the DBI
	# expression classes and login prompt, and maps a live DBI handle to the
	# name of the concrete TapKit adapter for that database.
	class DBIAdapter < Adapter
		# Class name of the DBD database handle => TapKit adapter name.
		DRIVER_ADAPTER_NAMES = {
			'DBI::DBD::Mysql::Database' => 'MySQL',
			'DBI::DBD::Pg::Database'    => 'PostgreSQL'
		}.freeze

		# Expression class used for SQL generation.
		def self.expression_class
			DBIExpression
		end

		# Adapter name ('MySQL', 'PostgreSQL') for the driver behind +dbi+,
		# or nil when the driver's handle class is not listed above.
		def self.adapter_name( dbi )
			driver_class_name = dbi.handle.class.to_s
			DRIVER_ADAPTER_NAMES[driver_class_name]
		end

		# Prompt that asks the user for DBI connection parameters.
		def self.login_prompt
			DBILoginPrompt.new
		end

		# A new expression factory bound to this adapter.
		def expression_factory
			DBIExpressionFactory.new self
		end
	end


	# Adapter channel that runs SQL through a Ruby/DBI database handle.
	#
	# The handle is opened lazily from the adapter's connection hash
	# ('url', 'user', 'password'). Result rows are mapped back onto the
	# attributes registered with select_attributes.
	class DBIChannel < AdapterChannel
		attr_accessor :auto_commit
		attr_reader :dbi
		alias auto_commit? auto_commit

		def initialize( adapter_context )
			super
			# entity => highest primary-key value handed out by
			# primary_keys_for_new_row (0 means "nothing cached").
			@pk_cache = Hash.new(0)
			@dbi = nil
			@state = nil
			@auto_commit = false
		end

		def fetch_progress?
			false
		end

		# Remembers which attributes fetch_all should map result columns onto.
		def select_attributes( attributes, lock, fetch_spec, entity )
			@attributes_to_fetch = attributes
		end

		# True while a DBI handle is held (set by open, cleared by close).
		def open?
			@dbi ? true : false
		end

		# Connects using the adapter's connection hash unless already open.
		# Applies this channel's auto_commit setting to the new handle.
		def open
			return if open?

			connection = @adapter_context.adapter.connection
			@dbi = DBI.connect(connection['url'], connection['user'], connection['password'])
			@dbi['AutoCommit'] = auto_commit
		end

		# Finishes any pending statement and disconnects the handle. The handle
		# is dropped even if disconnecting fails, so open? becomes false and a
		# later open reconnects.
		def close
			return unless @dbi

			begin
				_finish_statement
				@dbi.disconnect
			ensure
				@dbi = nil
			end
		end

		# Executes +expression+ with the values of its bind variables. Keeps
		# the resulting statement handle for fetch_all / rows and returns it.
		# Logs the expression when the application's :sql log option is set.
		def evaluate( expression )
			@entity = expression.entity
			open

			bindings = []
			expression.bind_variables.each do |variable|
				variable.each do |key, value|
					bindings << value if key == expression.class::VALUE_KEY
				end
			end

			if application && application.log_options[:sql] then
				application.log_options[:out].puts expression
			end

			# NOTE(review): this forgets keys already handed out by
			# primary_keys_for_new_row that may not be inserted yet; confirm a
			# later re-query cannot hand out the same keys again.
			@pk_cache.clear
			# Release the previous statement handle before replacing it.
			# Otherwise every evaluation leaks one.
			_finish_statement
			@state = @dbi.execute(expression.statement, *bindings)
		end

		# Fetches all remaining rows of the last statement as Hashes.
		# A column that matches a selected attribute (by column_name) is keyed
		# by the attribute name and converted with convert_attribute_value.
		# The model connection's 'encoding' is passed when the adapter has a
		# model. Any other column is keyed by its column name with the raw value.
		def fetch_all
			raw_rows = @state.fetch_all || []
			return [] if raw_rows.empty?

			column_names = @state.column_names
			# column name => attribute. On duplicate column names the last
			# attribute wins, as with the former linear scan.
			attributes = {}
			(@attributes_to_fetch || []).each do |attr|
				attributes[attr.column_name] = attr
			end

			model    = @adapter_context.adapter.model
			encoding = model && model.connection && model.connection['encoding']

			raw_rows.map do |raw_row|
				row = {}
				raw_row.each_with_index do |value, index|
					column = column_names[index]
					attr   = attributes[column]
					if attr.nil? then
						row[column] = value
					elsif model then
						row[attr.name] = attr.convert_attribute_value(value, encoding)
					else
						row[attr.name] = attr.convert_attribute_value(value)
					end
				end
				row
			end
		end

		# Yields every row returned by fetch_all.
		def each
			fetch_all.each do |row|
				yield row
			end
		end

		# Inserts +row+ into +entity+'s table and returns the affected row count.
		# Raises unless the context has an open transaction.
		def insert_row( row, entity )
			_check_transaction
			_affected_rows _expression_factory.insert_statement(row, entity)
		end

		# Raises unless the adapter context has an open transaction.
		def _check_transaction
			unless @adapter_context.open_transaction? then
				raise "channel's context has no transaction"
			end
		end

		# Deletes the rows of +entity+ matching +qualifier+ and returns the
		# affected row count. Raises unless a transaction is open.
		def delete_rows( qualifier, entity )
			_check_transaction
			_affected_rows _expression_factory.delete_statement(qualifier, entity)
		end

		# Updates the rows of +entity+ matching +qualifier+ with +row+ and
		# returns the affected row count. Raises unless a transaction is open.
		def update_rows( row, qualifier, entity )
			_check_transaction
			_affected_rows _expression_factory.update_statement(row, qualifier, entity)
		end

		# Creates new rows of primary keys.
		#
		# Returns +count+ Hashes of the form {attribute name => value}. Values
		# continue from the largest key currently in +entity+'s table, or from
		# the cached value of a previous call. Returns nil unless the entity
		# has exactly one primary-key attribute.
		def primary_keys_for_new_row( count, entity )
			attrs = entity.primary_key_attributes
			# Compound keys cannot be generated by incrementing. Without a key
			# attribute there is nothing to generate.
			return nil unless attrs.size == 1

			open
			attr = attrs.first
			column_name = attr.column_name

			if (maxnum = @pk_cache[entity]) > 0 then
				@pk_cache[entity] = maxnum + count
				return _primary_keys(attr, maxnum, count)
			end

			sql =  "SELECT #{column_name} FROM #{entity.external_name}"
			sql << " ORDER BY #{column_name} DESC LIMIT 1"

			expr = _expression_factory.create entity
			expr.statement = sql
			evaluate expr

			maxnum = _max_key_value(@state.fetch)
			@pk_cache[entity] = maxnum + count
			_primary_keys(attr, maxnum, count)
		end

		# Names of the tables visible through the DBI handle.
		def describe_table_names
			open
			@dbi.tables
		end

		# Builds a Model with one GenericRecord entity per name in
		# +table_names+. Attributes are read from the database's column
		# metadata, and the model uses this adapter's connection.
		def describe_model( table_names )
			open
			adapter_name = DBIAdapter.adapter_name(@dbi)
			@adapter_class = TapKit.const_get("#{adapter_name}Adapter")

			entities = []
			table_names.each { |table| entities << _describe_entity(table) }
			entities.each   { |entity| _describe_attributes entity }

			model              = Model.new
			model.entities     = entities
			model.adapter_name = adapter_name
			model.connection   = @adapter_context.adapter.connection

			model
		end

		private

		# Finishes the statement handle of the previous evaluation, if any,
		# and forgets it.
		def _finish_statement
			state, @state = @state, nil
			state.finish if state && !state.finished?
		end

		# Expression factory of this channel's adapter.
		def _expression_factory
			@adapter_context.adapter.expression_factory
		end

		# Evaluates a data-modification expression and returns the number of
		# affected rows reported by the statement handle.
		def _affected_rows( expression )
			evaluate expression
			@state.rows
		end

		# Start value for new keys from the fetched max-key row: 0 for an
		# empty table (or a NULL value), and an Integer for drivers that
		# return column values as Strings.
		def _max_key_value( row )
			value = row.nil? ? nil : row[0]
			return 0 if value.nil?
			value.kind_of?(String) ? value.to_i : value
		end

		# +count+ key Hashes {attribute name => maxnum + 1 .. maxnum + count}.
		def _primary_keys( attribute, maxnum, count )
			(1..count).map { |offset| { attribute.name => maxnum + offset } }
		end

		# name, class_name, external_name
		def _describe_entity( table )
			entity = Entity.new
			entity.external_name = table
			entity.beautify_name
			entity.class_name = 'GenericRecord'
			entity
		end

		# entity: attributes, primary_key_attributes
		# attribute: class_name, external_type, width
		#
		# Primary-key columns become non-null key attributes. Other columns
		# become class properties, except names ending in "_id" (case-insensitive).
		def _describe_attributes( entity )
			@dbi.columns(entity.external_name).each do |info|
				attr = Attribute.new
				attr.column_name = info['name']
				attr.beautify_name
				attr.class_name = @adapter_class.internal_type_name info['type_name']
				attr.external_type = info['type_name']

				if attr.class_name == 'String' then
					attr.width = info['precision']
				end

				if info.primary? then
					attr.allow_null = false
					entity.primary_key_attribute_names << attr.name
				elsif not (attr.name =~ /_id$/i) then
					entity.class_property_names << attr.name
				end
				entity.add_attribute attr
			end
		end
	end


	# AdapterContext for DBI adapters. It hands out DBIChannel instances and
	# runs transactions on the DBI handle of the channel chosen by
	# begin_transaction.
	class DBIContext < AdapterContext
		# AdapterChannel used by the receiver.
		attr_reader :channel

		# DBI database handle of the current channel, or nil when there is no
		# current channel.
		def dbi
			channel ? channel.dbi : nil
		end

		# Creates a DBIChannel for this context, registers it in @channels and
		# returns it.
		def create_channel
			new_channel = DBIChannel.new self
			@channels << new_channel
			new_channel
		end

		# Raises if a transaction is already open. Otherwise it makes the first
		# channel that is not in the middle of a fetch the current channel,
		# opening it if needed. If every channel is busy (or none exists), it
		# creates and opens a new one. Then it marks a transaction as open.
		def begin_transaction
			if open_transaction? then
				raise "already begun"
			end

			@channel = @channels.find { |candidate| not candidate.fetch_progress? }
			if @channel then
				@channel.open unless @channel.open?
			else
				@channel = create_channel
				@channel.open
			end

			@open_transaction     = true
			@commit_transaction   = false
			@rollback_transaction = false

			transaction_did_begin
		end

		# Commits on the current channel's handle and marks the transaction as
		# committed.
		def commit_transaction
			_transaction_dbi.commit

			@open_transaction   = false
			@commit_transaction = true

			transaction_did_commit
		end

		# Rolls back on the current channel's handle and marks the transaction
		# as rolled back.
		def rollback_transaction
			_transaction_dbi.rollback

			@open_transaction     = false
			@rollback_transaction = true

			transaction_did_rollback
		end

		private

		# The current DBI handle. Raises a descriptive error, instead of a
		# NoMethodError on nil, when there is no channel or handle.
		def _transaction_dbi
			handle = dbi
			unless handle then
				raise "no database connection; call begin_transaction first"
			end
			handle
		end
	end

	# SQL expression used by DBI adapters. Literal values are rendered by type
	# category (string / date / number), and every value is passed as a '?'
	# bind variable.
	class DBIExpression < SQLExpression
		STRING_TYPE = 1
		DATE_TYPE   = 2
		NUMBER_TYPE = 3

		# Name of an adapter internal type => value category used by
		# format_value.
		# NOTE(review): Class#to_s includes the namespace, so a class named
		# TapKit::Timestamp would not match 'Timestamp'; confirm what
		# internal_type returns.
		TYPE_CATEGORIES = {
			'String'    => STRING_TYPE,
			'Integer'   => NUMBER_TYPE,
			'Float'     => NUMBER_TYPE,
			'Date'      => DATE_TYPE,
			'Time'      => DATE_TYPE,
			'Timestamp' => DATE_TYPE
		}.freeze

		# Adapter whose type mapping format_value consults.
		def self.adapter_class
			DBIAdapter
		end

		# Renders +value+ as an SQL literal according to the internal type of
		# +attribute+'s external type: sql_for_string, sql_for_date or
		# sql_for_number. Raises for any other type.
		def format_value( value, attribute )
			internal_name = self.class.adapter_class.internal_type(attribute.external_type).to_s

			case TYPE_CATEGORIES[internal_name]
			when STRING_TYPE then sql_for_string value
			when DATE_TYPE   then sql_for_date value
			when NUMBER_TYPE then sql_for_number value
			else raise "unsupported type: '#{attribute.external_type}'"
			end
		end

		# Bind-variable description for +attribute+ using a '?' placeholder.
		def bind_variable( attribute, value )
			{
				NAME_KEY        => attribute.name,
				PLACEHOLDER_KEY => '?',
				ATTRIBUTE_KEY   => attribute,
				VALUE_KEY       => value_for_bind_variable(attribute, value)
			}
		end

		# Value passed to DBI for +value+. Strings are re-encoded when an
		# encoding is set; Date/Time/Timestamp values are sent as strings;
		# anything else (including nil) is passed through unchanged.
		def value_for_bind_variable( attribute, value )
			return nil if value.nil?

			case value
			when String
				encoding ? Utilities.encode(value, encoding) : value
			when Date, Time, Timestamp
				value.to_s
			else
				value
			end
		end

		# Every attribute is always bound, never inlined.
		def must_use_bind_variable?( attribute ) true end

		def should_use_bind_variable?( attribute ) true end

		def use_bind_variables?() true end

		# Builds the SELECT via SQLExpression and appends a LIMIT clause when
		# the fetch specification carries a positive limit.
		def prepare_select( attributes, lock, fetch_spec )
			sql = super
			limit = fetch_spec.limit
			sql << " LIMIT #{limit}" if limit && limit > 0
			@statement = sql
		end
	end

	# Expression factory whose expressions are DBIExpression instances.
	class DBIExpressionFactory < SQLExpressionFactory
		def expression_class() DBIExpression end
	end
end
