/*
 * Copyright 2009-2010 the Fess Project and the Others.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
 * either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */

package jp.sf.fess.servlet;

import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;

import org.h2.tools.Server;
import org.seasar.framework.util.DriverManagerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Servlet that runs the H2 TCP server backing the Fess configuration
 * database.
 * <p>
 * {@link #init()} starts the server using these optional servlet init
 * parameters:
 * <ul>
 * <li>{@code baseDir} - passed as {@code -baseDir}; defaults to the real path
 * of {@code /WEB-INF/db/}</li>
 * <li>{@code tcpAllowOthers} - adds {@code -tcpAllowOthers} when
 * {@code "true"} (case-insensitive)</li>
 * <li>{@code tcpPort} - passed as {@code -tcpPort}</li>
 * <li>{@code tcpSSL} - adds {@code -tcpSSL} when {@code "true"}
 * (case-insensitive)</li>
 * <li>{@code tcpPassword} - passed as {@code -tcpPassword}</li>
 * </ul>
 * {@link #destroy()} stops the server and then performs best-effort cleanup of
 * things that could keep this web application's class loader reachable: JDBC
 * drivers (via Seasar's {@code DriverManagerUtil}), threads whose context class
 * loader is this servlet's class loader, and thread-local entries whose key or
 * value was loaded by it.
 */
public class FessConfigServlet extends HttpServlet {
    private static final Logger logger = LoggerFactory
            .getLogger(FessConfigServlet.class);

    private static final long serialVersionUID = 1L;

    /** Thread groups whose threads are JVM controlled and never touched. */
    private static final String[] JVM_THREAD_GROUP_NAMES = { "system",
            "RMI Runtime" };

    /** How many times {@link #waitThread(Thread)} polls a thread. */
    private static final int WAIT_THREAD_MAX_POLLS = 5;

    /** Delay between two polls in {@link #waitThread(Thread)}, in ms. */
    private static final long WAIT_THREAD_POLL_INTERVAL_MS = 100L;

    /** The running H2 TCP server; {@code null} when not running. */
    protected Server server = null;

    /**
     * Builds the H2 command-line arguments from the servlet init parameters
     * and starts a TCP server with them.
     *
     * @throws ServletException if the server could not be started
     */
    @Override
    public void init() throws ServletException {
        final List<String> argList = new ArrayList<String>();

        try {
            argList.add("-baseDir");
            final String baseDir = getServletConfig().getInitParameter(
                    "baseDir");
            if (baseDir != null) {
                argList.add(baseDir);
            } else {
                argList.add(getServletContext().getRealPath("/WEB-INF/db/"));
            }

            argList.add("-tcp");

            // equalsIgnoreCase(null) is false, so no separate null check needed
            if ("true".equalsIgnoreCase(getServletConfig().getInitParameter(
                    "tcpAllowOthers"))) {
                argList.add("-tcpAllowOthers");
            }

            final String tcpPort = getServletConfig().getInitParameter(
                    "tcpPort");
            if (tcpPort != null) {
                argList.add("-tcpPort");
                argList.add(tcpPort);
            }

            if ("true".equalsIgnoreCase(getServletConfig().getInitParameter(
                    "tcpSSL"))) {
                argList.add("-tcpSSL");
            }

            final String tcpPassword = getServletConfig().getInitParameter(
                    "tcpPassword");
            if (tcpPassword != null) {
                argList.add("-tcpPassword");
                argList.add(tcpPassword);
            }

            server = Server.createTcpServer(
                    argList.toArray(new String[argList.size()])).start();
        } catch (final Exception e) {
            throw new ServletException("Could not start Fess Config DB.", e);
        }
    }

    /**
     * Stops the H2 server, deregisters JDBC drivers and cleans up threads and
     * thread-local values tied to this servlet's class loader.
     */
    @Override
    public void destroy() {
        if (server != null) {
            server.stop();
            // Avoid stopping the same server twice if destroy() is re-invoked
            server = null;
        }

        if (logger.isInfoEnabled()) {
            logger.info("Removing all drivers...");
        }
        DriverManagerUtil.deregisterAllDrivers();

        cleanupAllThreads();
    }

    /**
     * Best-effort cleanup of threads and thread-local values that reference
     * this servlet's class loader.
     */
    private void cleanupAllThreads() {
        final Thread[] threads = getThreads();
        final ClassLoader cl = this.getClass().getClassLoader();
        try {
            // NOTE(review): the intent of this probe is not evident from this
            // file; its result and any exception are deliberately ignored.
            cl.getResource(null);
        } catch (final Exception ignored) {
            // ignore
        }

        stopThreads(threads, cl);
        clearThreadLocals(threads, cl);
    }

    /**
     * Interrupts, and if that is not enough stops, every live thread whose
     * context class loader is {@code cl}. The current thread and threads in
     * JVM-controlled thread groups are skipped.
     *
     * @param threads snapshot of threads; may contain {@code null} slots
     * @param cl the class loader whose threads should be terminated
     */
    private void stopThreads(final Thread[] threads, final ClassLoader cl) {
        for (final Thread thread : threads) {
            if (thread == null) {
                continue;
            }
            final ClassLoader ccl = thread.getContextClassLoader();
            if (ccl == null || ccl != cl) {
                continue;
            }
            // Never touch the thread running this cleanup
            if (thread == Thread.currentThread()) {
                continue;
            }
            // Never touch JVM controlled threads
            if (isJvmControlledThread(thread)) {
                continue;
            }

            // Give the thread a chance to finish on its own first
            waitThread(thread);
            if (!thread.isAlive()) {
                continue;
            }

            if (logger.isInfoEnabled()) {
                logger.info("Interrupting a thread [" + thread.getName()
                        + "]...");
            }
            thread.interrupt();

            waitThread(thread);
            if (!thread.isAlive()) {
                continue;
            }

            if (logger.isInfoEnabled()) {
                logger.info("Stopping a thread [" + thread.getName() + "]...");
            }
            stopThread(thread);
        }
    }

    /**
     * Returns whether the thread belongs to one of
     * {@link #JVM_THREAD_GROUP_NAMES}.
     */
    private boolean isJvmControlledThread(final Thread thread) {
        final ThreadGroup tg = thread.getThreadGroup();
        if (tg == null) {
            return false;
        }
        final String groupName = tg.getName();
        for (final String jvmGroupName : JVM_THREAD_GROUP_NAMES) {
            if (jvmGroupName.equals(groupName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Forcibly stops a thread, logging instead of propagating failures so the
     * remaining cleanup still runs.
     */
    @SuppressWarnings("deprecation")
    private void stopThread(final Thread thread) {
        try {
            thread.stop();
        } catch (final RuntimeException e) {
            // Thread.stop() throws UnsupportedOperationException on JDK 20 and
            // later, and may throw SecurityException under a security manager.
            logger.warn("Could not stop a thread [" + thread.getName() + "].",
                    e);
        }
    }

    /**
     * Removes thread-local entries loaded by {@code cl} from both the
     * {@code threadLocals} and {@code inheritableThreadLocals} maps of every
     * given thread. Relies on JDK-internal fields via reflection; if they are
     * not accessible on this JVM, nothing is cleared.
     * <p>
     * NOTE: a thread's ThreadLocalMap is only meant to be accessed by its
     * owning thread, so modifying a live thread's map is inherently racy. This
     * cleanup is therefore best-effort.
     *
     * @param threads snapshot of threads; may contain {@code null} slots
     * @param cl the class loader whose thread-local entries are removed
     */
    private void clearThreadLocals(final Thread[] threads, final ClassLoader cl) {
        Field threadLocalsField = null;
        Field inheritableThreadLocalsField = null;
        Field tableField = null;
        try {
            threadLocalsField = Thread.class.getDeclaredField("threadLocals");
            threadLocalsField.setAccessible(true);
            inheritableThreadLocalsField = Thread.class
                    .getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);
            // Make the underlying array of ThreadLocal.ThreadLocalMap.Entry
            // objects accessible
            final Class<?> tlmClass = Class
                    .forName("java.lang.ThreadLocal$ThreadLocalMap");
            tableField = tlmClass.getDeclaredField("table");
            tableField.setAccessible(true);
        } catch (final Exception e) {
            // Without these fields no map can be inspected, so stop here
            if (logger.isDebugEnabled()) {
                logger.debug("Could not access thread local internals.", e);
            }
            return;
        }

        for (final Thread thread : threads) {
            if (thread == null) {
                continue;
            }
            try {
                clearThreadLocalMap(cl, threadLocalsField.get(thread),
                        tableField);
            } catch (final Exception e) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Could not clear thread locals of a thread ["
                            + thread.getName() + "].", e);
                }
            }
            try {
                clearThreadLocalMap(cl,
                        inheritableThreadLocalsField.get(thread), tableField);
            } catch (final Exception e) {
                if (logger.isDebugEnabled()) {
                    logger.debug(
                            "Could not clear inheritable thread locals of a thread ["
                                    + thread.getName() + "].", e);
                }
            }
        }
    }

    /**
     * Sleeps while the thread is alive, polling at most
     * {@link #WAIT_THREAD_MAX_POLLS} times. Returns early, with the interrupt
     * status restored, if the current thread is interrupted.
     */
    private void waitThread(final Thread thread) {
        for (int count = 0; thread.isAlive() && count < WAIT_THREAD_MAX_POLLS; count++) {
            try {
                Thread.sleep(WAIT_THREAD_POLL_INTERVAL_MS);
            } catch (final InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Returns a snapshot of all threads reachable from the root thread group.
     * The array may contain trailing {@code null} slots.
     */
    private Thread[] getThreads() {
        // Walk up to the root thread group
        ThreadGroup tg = Thread.currentThread().getThreadGroup();
        while (tg.getParent() != null) {
            tg = tg.getParent();
        }

        int threadCountGuess = tg.activeCount() + 50;
        Thread[] threads = new Thread[threadCountGuess];
        int threadCountActual = tg.enumerate(threads);
        // enumerate(Thread[]) silently drops threads that do not fit, so grow
        // the array until it is not completely filled
        while (threadCountActual == threadCountGuess) {
            threadCountGuess *= 2;
            threads = new Thread[threadCountGuess];
            threadCountActual = tg.enumerate(threads);
        }

        return threads;
    }

    /**
     * Removes from one ThreadLocalMap every entry whose key or value is
     * {@code cl} itself or an object whose class was loaded by {@code cl}.
     * Stale entries (already collected ThreadLocal key) are skipped because
     * they cannot be removed by key.
     *
     * @param cl the class loader whose entries are removed
     * @param map a {@code ThreadLocal.ThreadLocalMap} instance, or {@code null}
     * @param internalTableField the accessible {@code ThreadLocalMap.table} field
     */
    private void clearThreadLocalMap(final ClassLoader cl, final Object map,
            final Field internalTableField) throws NoSuchMethodException,
            IllegalAccessException, NoSuchFieldException,
            InvocationTargetException {
        if (map == null) {
            return;
        }
        final Method mapRemove = map.getClass().getDeclaredMethod("remove",
                ThreadLocal.class);
        mapRemove.setAccessible(true);
        final Object[] table = (Object[]) internalTableField.get(map);
        if (table == null) {
            return;
        }

        // Each entry is a Reference whose referent is the ThreadLocal key;
        // look the field up once rather than per entry
        final Field keyField = Reference.class.getDeclaredField("referent");
        keyField.setAccessible(true);
        Field valueField = null;

        for (final Object entry : table) {
            if (entry == null) {
                continue;
            }
            final Object key = keyField.get(entry);
            if (valueField == null) {
                valueField = entry.getClass().getDeclaredField("value");
                valueField.setAccessible(true);
            }
            final Object value = valueField.get(entry);
            if (!isLoadedBy(cl, key) && !isLoadedBy(cl, value)) {
                continue;
            }
            if (key == null) {
                // Stale entry: the ThreadLocal was already collected. Removing
                // by a null key would throw a NullPointerException inside
                // ThreadLocalMap.remove (OpenJDK), so leave it to the map's
                // own stale-entry expunging.
                if (logger.isDebugEnabled()) {
                    logger.debug("Skipping a stale thread local entry...");
                }
                continue;
            }
            if (logger.isInfoEnabled()) {
                logger.info("Removing " + key.toString()
                        + " from a thread local...");
            }
            mapRemove.invoke(map, key);
        }
    }

    /**
     * Returns whether {@code obj} is {@code cl} itself or an instance of a
     * class loaded by {@code cl}.
     */
    private static boolean isLoadedBy(final ClassLoader cl, final Object obj) {
        return cl.equals(obj)
                || (obj != null && cl == obj.getClass().getClassLoader());
    }

}
