/*
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */
package org.dbunitng.listeners;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Properties;

import org.apache.commons.dbcp.BasicDataSource;
import org.apache.commons.lang.StringUtils;
import org.dbunit.Assertion;
import org.dbunit.DataSourceDatabaseTester;
import org.dbunit.IDatabaseTester;
import org.dbunit.JdbcDatabaseTester;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.database.QueryDataSet;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.csv.CsvDataSet;
import org.dbunit.dataset.excel.XlsDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunitng.annotations.DatabaseOperationType;
import org.dbunitng.annotations.DbUnitNG;
import org.dbunitng.annotations.FileType;
import org.dbunitng.annotations.SetUpOperation;
import org.dbunitng.annotations.TableAssert;
import org.dbunitng.annotations.TearDownOperation;
import org.dbunitng.exception.DbUnitNGRuntimeException;
import org.dbunitng.exception.TestDataFileNotFoundException;
import org.dbunitng.listeners.internal.DbUnitNGConfig;
import org.dbunitng.listeners.internal.DbUnitNGDatabaseOperation;
import org.dbunitng.util.ResourceUtil;
import org.testng.ITestContext;
import org.testng.ITestListener;
import org.testng.ITestNGMethod;
import org.testng.ITestResult;
import org.testng.TestRunner;
import org.testng.log4testng.Logger;

/**
 * TestNG listener that integrates DbUnit with TestNG.
 * <p>
 * When a test context starts, it builds a DbUnit {@link IDatabaseTester} from
 * either a {@link DbUnitNG} annotation found on a passed configuration method
 * or, when no such annotation exists, from the test context (testng.xml).
 * Around each test method it applies the data files named by
 * {@link SetUpOperation} and {@link TearDownOperation}. When a test succeeds,
 * it verifies {@link TableAssert} expectations.
 * 
 * @author jyukutyo
 * 
 */
public class DbUnitNGTestListener implements ITestListener {

	/** This class' log4testng Logger. */
	private static final Logger LOGGER =
		Logger.getLogger(DbUnitNGTestListener.class);

	/** File name whose presence in a path selects a CSV directory data set. */
	private static final String TABLE_ORDERING_FILE = "table-ordering.txt";

	/** DbUnit tester created in {@link #onStart(ITestContext)}. */
	private IDatabaseTester tester;

	/** Database connection settings resolved in {@link #onStart(ITestContext)}. */
	private DbUnitNGConfig config;

	/**
	 * Callback invoked when a test context starts. Resolves the database
	 * connection settings and creates the DbUnit tester.
	 * 
	 * @param context
	 *            the TestNG test context
	 */
	public void onStart(ITestContext context) {

		LOGGER.debug("DbUnitNGTestListener#onStart");

		DbUnitNGContext.start();

		changeTestListenerOrder(context);

		Collection<ITestNGMethod> methods =
			context.getPassedConfigurations().getAllMethods();
		DbUnitNG dbUnitNG = getAnnotation(methods, DbUnitNG.class);

		if (dbUnitNG == null) {
			LOGGER.debug("Jdbc Configuration is created from testng.xml file.");
			config = DbUnitNGConfig.create(context);

		} else {
			LOGGER
				.debug("Jdbc Configuration is created from @DbUnitNG annotaion.");
			config = DbUnitNGConfig.create(dbUnitNG);
		}
		config.verifyParamsNotNull();

		createTester();
	}

	/**
	 * Moves every {@code DbUnitNGTestListener} to the front of the runner's
	 * test-listener list, presumably so that TestNG notifies it before the
	 * other listeners. Does nothing when the context is not a
	 * {@link TestRunner}.
	 * 
	 * @param context
	 *            the TestNG test context
	 */
	protected void changeTestListenerOrder(ITestContext context) {
		if (!(context instanceof TestRunner)) {
			return;
		}
		TestRunner runner = (TestRunner) context;
		List<ITestListener> list = runner.getTestListeners();

		// Rank-based comparison keeps the Comparator contract: two listeners of
		// the same kind compare as equal. Collections.sort is stable, so the
		// relative order of all other listeners is preserved.
		Collections.sort(list, new Comparator<ITestListener>() {
			public int compare(ITestListener o1, ITestListener o2) {
				return rank(o1) - rank(o2);
			}

			private int rank(ITestListener listener) {
				return listener instanceof DbUnitNGTestListener ? 0 : 1;
			}
		});
	}

	/**
	 * Creates the DbUnit tester from {@link #config}: a
	 * {@link DataSourceDatabaseTester} backed by a DBCP
	 * {@link BasicDataSource} when DBCP is enabled, otherwise a
	 * {@link JdbcDatabaseTester}.
	 */
	protected void createTester() {
		if (config.isDbcp()) {
			BasicDataSource source = new BasicDataSource();
			source.setDriverClassName(config.getDriver());
			source.setUrl(config.getUrl());
			source.setUsername(config.getUserName());
			source.setPassword(config.getPassword());
			tester = new DataSourceDatabaseTester(source);
			LOGGER.debug("DataSourceDatabaseTester is created.");
		} else {
			tester =
				new JdbcDatabaseTester(
					config.getDriver(),
					config.getUrl(),
					config.getUserName(),
					config.getPassword());
			LOGGER.debug("JdbcDatabaseTester is created.");
		}
	}

	/**
	 * Returns the first annotation of the given type found on any of the
	 * given test methods.
	 * 
	 * @param <T>
	 *            annotation type
	 * @param methods
	 *            test methods to inspect, in iteration order
	 * @param annotationClass
	 *            annotation class to look for
	 * @return the first matching annotation, or {@code null} if none is found
	 */
	protected <T extends Annotation> T getAnnotation(
			Collection<ITestNGMethod> methods, Class<T> annotationClass) {

		for (ITestNGMethod testNGMethod : methods) {
			Method method = testNGMethod.getMethod();
			T annotation = method.getAnnotation(annotationClass);
			if (annotation != null) {
				return annotation;
			}
		}
		return null;
	}

	/**
	 * Returns the annotation of the given type declared on the test method of
	 * the given result.
	 * 
	 * @param <T>
	 *            annotation type
	 * @param result
	 *            the test result whose method is inspected
	 * @param annotationClass
	 *            annotation class to look for
	 * @return the annotation, or {@code null} if the method does not have it
	 */
	protected <T extends Annotation> T getAnnotation(ITestResult result,
			Class<T> annotationClass) {

		ITestNGMethod testNGMethod = result.getMethod();
		Method method = testNGMethod.getMethod();

		return method.getAnnotation(annotationClass);
	}

	/**
	 * Callback invoked before a test method runs. Applies the
	 * {@link SetUpOperation} data file, if any.
	 * 
	 * @param result
	 *            the test result of the starting method
	 */
	public void onTestStart(ITestResult result) {

		DbUnitNGContext.startTest(result);

		SetUpOperation setUpOperation =
			this.getAnnotation(result, SetUpOperation.class);
		if (setUpOperation == null) {
			return;
		}
		DatabaseOperationType type =
			resolveOperationType(setUpOperation.value());
		if (isExecutable(type)) {
			readFileForDatabase(result, type, setUpOperation.pathname());
		}
	}

	/**
	 * Loads a data set from a path.
	 * <p>
	 * If the path contains {@value #TABLE_ORDERING_FILE}, its directory is
	 * loaded as a {@link CsvDataSet}. When the path has no '/', the test
	 * class' package directory is used instead. Otherwise, the file is loaded
	 * as Excel or flat XML according to its extension.
	 * 
	 * @param pathName
	 *            data file path (or bare file name)
	 * @param result
	 *            the test result, used to resolve bare file names
	 * @return the loaded data set
	 * @throws DbUnitNGRuntimeException
	 *             if the file cannot be read or its type is unknown
	 */
	protected IDataSet toDataSet(String pathName, ITestResult result) {

		int index = pathName.lastIndexOf(TABLE_ORDERING_FILE);
		if (index != -1) {
			// case expected file is CSV
			String dir;
			if (pathName.indexOf('/') < 0) {
				// case only file name is specified
				dir = packageDirectoryOf(result);
			} else {
				// drop "/table-ordering.txt" and everything after it
				dir = pathName.substring(0, index - 1);
			}

			try {
				return new CsvDataSet(new File(ResourceUtil.getURI(dir)));
			} catch (DataSetException e) {
				throw new DbUnitNGRuntimeException(
					"No such directory:" + dir,
					e);
			}
		}

		String extension = ResourceUtil.getExtension(pathName);
		FileType fileType = ResourceUtil.toFileType(extension);

		InputStream stream = getFileStream(pathName, result);
		try {
			if (FileType.EXCEL == fileType) {
				return new XlsDataSet(stream);
			} else if (FileType.XML == fileType) {
				return new FlatXmlDataSet(stream);
			} else {
				throw new DbUnitNGRuntimeException("FileType is unknown. "
					+ fileType);
			}
		} catch (DataSetException e) {
			throw new DbUnitNGRuntimeException(e);
		} catch (IOException e) {
			throw new DbUnitNGRuntimeException(e);
		} finally {
			// The data set constructors read the stream eagerly; release it.
			closeQuietly(stream);
		}
	}

	/**
	 * Opens a classpath resource. A bare file name (no '/') is resolved
	 * against the directory of the test class' package. The caller is
	 * responsible for closing the returned stream.
	 * 
	 * @param pathName
	 *            resource path or bare file name
	 * @param result
	 *            the test result, used to resolve bare file names
	 * @return an open stream, never {@code null}
	 * @throws TestDataFileNotFoundException
	 *             if {@code pathName} is blank
	 * @throws DbUnitNGRuntimeException
	 *             if the resource cannot be opened
	 */
	protected InputStream getFileStream(String pathName, ITestResult result) {
		if (StringUtils.isBlank(pathName)) {
			// case file path is not specified
			throw new TestDataFileNotFoundException();
		}
		String resolvedPath = pathName;
		if (pathName.indexOf('/') < 0) {
			// case only file name is specified
			resolvedPath = packageDirectoryOf(result) + "/" + pathName;
		}
		InputStream stream = ResourceUtil.getResourceAsStream(resolvedPath);
		if (stream == null) {
			throw new DbUnitNGRuntimeException("Test data file is not found: "
				+ resolvedPath);
		}
		return stream;
	}

	/**
	 * Reads a data file and applies it to the database.
	 * 
	 * @param result
	 *            the test result, used to resolve bare file names
	 * @param type
	 *            database operation to perform
	 * @param pathName
	 *            data file path
	 */
	protected void readFileForDatabase(ITestResult result,
			DatabaseOperationType type, String pathName) {

		IDataSet dataSet = toDataSet(pathName, result);
		DbUnitNGDatabaseOperation operation =
			new DbUnitNGDatabaseOperation(tester, config);
		operation.execute(type, dataSet);
	}

	/**
	 * Callback invoked when a test context finishes.
	 * 
	 * @param context
	 *            the TestNG test context
	 */
	public void onFinish(ITestContext context) {
		DbUnitNGContext.end();
		LOGGER.debug("DbUnitNGTestListener#onFinish");
	}

	/**
	 * Runs tear-down: applies the {@link TearDownOperation} data file, if any.
	 * Called on success and on both kinds of failure.
	 * 
	 * @param result
	 *            the test result of the finished method
	 */
	protected void onTestFinishWhateverHappens(ITestResult result) {

		DbUnitNGContext.endTest();

		TearDownOperation tearDownOperation =
			this.getAnnotation(result, TearDownOperation.class);
		if (tearDownOperation == null) {
			return;
		}
		DatabaseOperationType type =
			resolveOperationType(tearDownOperation.value());
		if (isExecutable(type)) {
			readFileForDatabase(result, type, tearDownOperation.pathname());
		}
	}

	/**
	 * Replaces {@code USE_DEFAULT} with the configured default operation.
	 * 
	 * @param declared
	 *            operation declared on the annotation
	 * @return the effective operation
	 */
	private DatabaseOperationType resolveOperationType(
			DatabaseOperationType declared) {
		if (declared == DatabaseOperationType.USE_DEFAULT) {
			return config.getDefaultOperation();
		}
		return declared;
	}

	/**
	 * Tells whether an effective operation should touch the database.
	 * {@code NONE}, and a {@code USE_DEFAULT} still unresolved after
	 * defaulting, are skipped.
	 * 
	 * @param type
	 *            effective operation
	 * @return {@code true} if a data file should be applied
	 */
	private static boolean isExecutable(DatabaseOperationType type) {
		return type != DatabaseOperationType.NONE
			&& type != DatabaseOperationType.USE_DEFAULT;
	}

	public void onTestFailedButWithinSuccessPercentage(ITestResult result) {
		onTestFinishWhateverHappens(result);
	}

	public void onTestFailure(ITestResult result) {
		onTestFinishWhateverHappens(result);
	}

	/**
	 * Verifies {@link TableAssert} (if present) and then runs tear-down. A
	 * failed table assertion marks the result as {@code FAILURE}.
	 * 
	 * @param result
	 *            the test result of the succeeded method
	 */
	public void onTestSuccess(ITestResult result) {

		TableAssert tableAssert = this.getAnnotation(result, TableAssert.class);
		if (tableAssert != null) {
			assertTable(tableAssert, result);
		}
		onTestFinishWhateverHappens(result);
	}

	/**
	 * Compares the expected data file against the actual database contents.
	 * Any throwable raised while building or comparing the data sets is
	 * recorded on {@code result}, whose status is set to {@code FAILURE}. The
	 * database connection used for the queries is always closed.
	 * 
	 * @param tableAssert
	 *            the assertion definition
	 * @param result
	 *            the test result to update on failure
	 */
	protected void assertTable(TableAssert tableAssert, ITestResult result) {

		IDatabaseConnection connection = null;
		try {
			connection = openConnection();
			IDataSet actual =
				createQueryDataSet(tableAssert, result, connection);
			IDataSet expected = toDataSet(tableAssert.pathname(), result);
			// QueryDataSet reads lazily, so the connection must stay open
			// until the comparison has finished.
			Assertion.assertEquals(expected, actual);
		} catch (Throwable e) {
			result.setStatus(ITestResult.FAILURE);
			result.setThrowable(e);
		} finally {
			closeConnection(connection);
		}
	}

	/**
	 * Builds the actual data set for a {@link TableAssert} on a newly opened
	 * connection. The connection cannot be closed by the caller, so prefer
	 * {@link #createQueryDataSet(TableAssert, ITestResult, IDatabaseConnection)}.
	 * 
	 * @param tableAssert
	 *            the assertion definition
	 * @param result
	 *            the test result, used to resolve bare file names
	 * @return the query data set
	 */
	protected IDataSet createQueryDataSet(TableAssert tableAssert,
			ITestResult result) {
		IDatabaseConnection connection = openConnection();
		try {
			return createQueryDataSet(tableAssert, result, connection);
		} catch (RuntimeException e) {
			closeConnection(connection);
			throw e;
		}
	}

	/**
	 * Builds the actual data set for a {@link TableAssert}.
	 * <ul>
	 * <li>No property file and no queries: whole tables named by
	 * {@code names}.</li>
	 * <li>No property file: {@code queries[i]} loaded as table
	 * {@code names[i]}.</li>
	 * <li>Property file: the query stored under {@code keys[i]} loaded as
	 * table {@code names[i]}.</li>
	 * </ul>
	 * 
	 * @param tableAssert
	 *            the assertion definition
	 * @param result
	 *            the test result, used to resolve bare file names
	 * @param connection
	 *            open connection; owned and closed by the caller
	 * @return the query data set
	 * @throws DbUnitNGRuntimeException
	 *             if the annotation is inconsistent or the property file
	 *             cannot be read
	 */
	protected IDataSet createQueryDataSet(TableAssert tableAssert,
			ITestResult result, IDatabaseConnection connection) {
		String[] queries = tableAssert.queries();
		String file = tableAssert.propertyFilePath();
		String[] keys = tableAssert.keys();
		String[] names = tableAssert.names();
		QueryDataSet queryDataSet;
		try {
			queryDataSet = new QueryDataSet(connection);
		} catch (Exception e) {
			throw new DbUnitNGRuntimeException(
				"Can not create QueryDataSet.",
				e);
		}
		// An explicitly empty queries array is treated like "no queries".
		boolean noQueries =
			queries.length == 0 || StringUtils.isEmpty(queries[0]);
		if (StringUtils.isBlank(file) && noQueries) {
			for (int i = 0; i < names.length; i++) {
				queryDataSet.addTable(names[i]);
			}
		} else if (StringUtils.isBlank(file)) {

			if (queries.length != names.length) {
				throw new DbUnitNGRuntimeException(
					"In @TableAssert, queries and names must be same length. [queries.length="
						+ queries.length + "][names.length=" + names.length
						+ "]");
			}

			for (int i = 0; i < queries.length; i++) {
				LOGGER.info("Query:" + queries[i]);
				queryDataSet.addTable(names[i], queries[i]);
			}
		} else {

			if (names.length != keys.length) {
				throw new DbUnitNGRuntimeException(
					"should be same length @TableAssert names and keys");
			}

			Properties properties = new Properties();
			InputStream stream = getFileStream(file, result);
			try {
				properties.load(stream);
			} catch (IOException e) {
				throw new DbUnitNGRuntimeException(e);
			} finally {
				closeQuietly(stream);
			}
			for (int i = 0; i < names.length; i++) {
				if (!properties.containsKey(keys[i])) {
					throw new DbUnitNGRuntimeException(
						"@TableAssert property file doesn't contain the key '"
							+ keys[i] + "'.");
				}
				String query = properties.getProperty(keys[i]);
				LOGGER.info("Query:" + query);
				queryDataSet.addTable(names[i], query);
			}
		}
		return queryDataSet;
	}

	public void onTestSkipped(ITestResult result) {
		// empty
	}

	/**
	 * Opens a new connection from the tester.
	 * 
	 * @return an open connection; the caller must close it
	 */
	private IDatabaseConnection openConnection() {
		try {
			return tester.getConnection();
		} catch (Exception e) {
			throw new DbUnitNGRuntimeException(
				"Can not create QueryDataSet.",
				e);
		}
	}

	/**
	 * Returns the classpath directory that corresponds to the package of the
	 * test class of {@code result}.
	 */
	private static String packageDirectoryOf(ITestResult result) {
		return ResourceUtil.replacePackageToDirectory(result
			.getTestClass()
			.getRealClass()
			.getPackage()
			.getName());
	}

	/** Closes a stream, logging instead of propagating close failures. */
	private static void closeQuietly(InputStream stream) {
		if (stream == null) {
			return;
		}
		try {
			stream.close();
		} catch (IOException e) {
			LOGGER.debug("Failed to close stream.", e);
		}
	}

	/** Closes a connection, logging instead of propagating close failures. */
	private static void closeConnection(IDatabaseConnection connection) {
		if (connection == null) {
			return;
		}
		try {
			connection.close();
		} catch (SQLException e) {
			LOGGER.warn("Failed to close the database connection.", e);
		}
	}
}
