/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.LocalString;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.classifiers.trees.DecisionStump;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.UnsupportedClassTypeException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * LogitBoost: additive logistic regression for nominal class attributes.
 *
 * <p>In every boosting iteration one base (regression) classifier is trained per
 * class, on a copy of the data whose class attribute has been replaced by a
 * numeric "pseudo class" that holds the LogitBoost working response. Optionally,
 * internal cross-validation chooses the number of iterations, and training stops
 * early once the change in average negative log-likelihood between iterations
 * falls below {@link #m_Precision}.
 */
public class LogitBoost
extends RandomizableIteratedSingleClassifierEnhancer
implements Sourcable,
WeightedInstancesHandler {
    /** Base models, indexed as [class][iteration]. */
    protected Classifier[][] m_Classifiers;
    /** Number of classes of the (nominal) class attribute. */
    protected int m_NumClasses;
    /** Number of boosting iterations actually performed so far. */
    protected int m_NumGenerated;
    /** Folds for internal cross-validation; values below 2 disable it. */
    protected int m_NumFolds = 0;
    /** Number of internal cross-validation runs. */
    protected int m_NumRuns = 1;
    /** Percentage of weight mass used to train each base model (100 = all). */
    protected int m_WeightThreshold = 100;
    /** Bound on the absolute value of the working response z. */
    protected static final double Z_MAX = 3.0;
    /** Header of the training data with the class replaced by the numeric pseudo class. */
    protected Instances m_NumericClassData;
    /** The original (nominal) class attribute. */
    protected Attribute m_ClassAttribute;
    /** Whether to resample instead of passing weights to the base learner. */
    protected boolean m_UseResampling;
    /** Training stops once |previous - current| average negative log-likelihood is below this. */
    protected double m_Precision = -Double.MAX_VALUE;
    /** Multiplier applied to every base-model prediction (1 = no shrinkage). */
    protected double m_Shrinkage = 1.0;
    /** Random source seeded from m_Seed at the start of buildClassifier. */
    protected Random m_RandomInstance = null;
    /** Offset applied to the 0/1 class indicators; never changed within this class. */
    protected double m_Offset = 0.0;

    /**
     * Returns a description of this classifier for GUI display.
     *
     * @return the description
     */
    public String globalInfo() {
        return LocalString.get("Class for performing additive logistic regression. ")
            + LocalString.get("This class performs classification using a regression scheme as the ")
            + LocalString.get("base learner, and can handle multi-class problems.  For more ")
            + LocalString.get("information, see\n\n")
            + LocalString.get("Friedman, J., T. Hastie and R. Tibshirani (1998) \"Additive Logistic ")
            + LocalString.get("Regression: a Statistical View of Boosting\". Technical report. ")
            + LocalString.get("Stanford University.\n\n")
            + LocalString.get("Can do efficient internal cross-validation to determine ")
            + LocalString.get("appropriate number of iterations.");
    }

    /** Creates a LogitBoost instance with a DecisionStump base learner. */
    public LogitBoost() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns the class name of the default base learner.
     *
     * @return the default base-learner class name
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Selects the heaviest instances until the given fraction of the total weight
     * mass is covered. Instances tied in weight with the last selected one are
     * all kept, so a run of equal weights is never split.
     *
     * @param data the weighted instances to select from
     * @param quantile fraction (0..1) of the total weight mass to keep
     * @return a new dataset holding copies of the selected instances
     */
    protected Instances selectWeightQuantile(Instances data, double quantile) {
        int numInstances = data.numInstances();
        Instances selected = new Instances(data, numInstances);
        double[] weights = new double[numInstances];
        double totalWeight = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            weights[i] = data.instance(i).weight();
            totalWeight += weights[i];
        }
        double targetMass = totalWeight * quantile;
        // Utils.sort yields indices in ascending weight order, so walk it backwards.
        int[] sortedIndices = Utils.sort(weights);
        double selectedMass = 0.0;
        for (int i = numInstances - 1; i >= 0; --i) {
            int index = sortedIndices[i];
            selected.add((Instance) data.instance(index).copy());
            selectedMass += weights[index];
            boolean nextDiffers = i > 0 && weights[index] != weights[sortedIndices[i - 1]];
            if (selectedMass > targetMass && nextDiffers) {
                break;
            }
        }
        if (this.m_Debug) {
            System.err.println(LocalString.get("Selected ") + selected.numInstances() + LocalString.get(" out of ") + numInstances);
        }
        return selected;
    }

    /**
     * Lists the options understood by this classifier, followed by those of the
     * superclass.
     *
     * @return an enumeration of Option objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(6);
        options.addElement(new Option(LocalString.get("\tUse resampling for boosting."), "Q", 0, "-Q"));
        options.addElement(new Option(LocalString.get("\tPercentage of weight mass to base training on.\n") + LocalString.get("\t(default 100, reduce to around 90 speed up)"), "P", 1, LocalString.get("-P <percent>")));
        options.addElement(new Option(LocalString.get("\tNumber of folds for internal cross-validation.\n") + LocalString.get("\t(default 0 -- no cross-validation)"), "F", 1, LocalString.get("-F <num>")));
        options.addElement(new Option(LocalString.get("\tNumber of runs for internal cross-validation.\n") + LocalString.get("\t(default 1)"), "R", 1, LocalString.get("-R <num>")));
        options.addElement(new Option(LocalString.get("\tThreshold on the improvement of the likelihood.\n") + LocalString.get("\t(default -Double.MAX_VALUE)"), "L", 1, LocalString.get("-L <num>")));
        options.addElement(new Option(LocalString.get("\tShrinkage parameter.\n") + LocalString.get("\t(default 1)"), "H", 1, LocalString.get("-H <num>")));
        Enumeration inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            options.addElement((Option) inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options -F, -R, -P, -L, -H and -Q (missing options reset to their
     * defaults), then passes the remainder to the superclass.
     *
     * @param options the option array
     * @throws Exception if an option cannot be parsed, or if -Q is combined with -P
     */
    public void setOptions(String[] options) throws Exception {
        String numFolds = Utils.getOption('F', options);
        this.setNumFolds(numFolds.length() != 0 ? Integer.parseInt(numFolds) : 0);

        String numRuns = Utils.getOption('R', options);
        this.setNumRuns(numRuns.length() != 0 ? Integer.parseInt(numRuns) : 1);

        String weightThreshold = Utils.getOption('P', options);
        this.setWeightThreshold(weightThreshold.length() != 0 ? Integer.parseInt(weightThreshold) : 100);

        String likelihoodThreshold = Utils.getOption('L', options);
        this.setLikelihoodThreshold(likelihoodThreshold.length() != 0
            ? Double.parseDouble(likelihoodThreshold) : -Double.MAX_VALUE);

        String shrinkage = Utils.getOption('H', options);
        this.setShrinkage(shrinkage.length() != 0 ? Double.parseDouble(shrinkage) : 1.0);

        this.setUseResampling(Utils.getFlag('Q', options));
        if (this.m_UseResampling && weightThreshold.length() != 0) {
            // Separate the two localized fragments with a space; they used to run together.
            throw new Exception(LocalString.get("Weight pruning with resampling") + " " + LocalString.get("not allowed."));
        }
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array. The array is padded with
     * empty strings to a fixed length.
     *
     * @return the option array
     */
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 10];
        int current = 0;
        // -Q and -P are mutually exclusive (see setOptions).
        if (this.getUseResampling()) {
            options[current++] = "-Q";
        } else {
            options[current++] = "-P";
            options[current++] = "" + this.getWeightThreshold();
        }
        options[current++] = "-F";
        options[current++] = "" + this.getNumFolds();
        options[current++] = "-R";
        options[current++] = "" + this.getNumRuns();
        options[current++] = "-L";
        options[current++] = "" + this.getLikelihoodThreshold();
        options[current++] = "-H";
        options[current++] = "" + this.getShrinkage();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @return tip text for the shrinkage property */
    public String shrinkageTipText() {
        return LocalString.get("Shrinkage parameter (use small value like 0.1 to reduce ") + "overfitting).";
    }

    /** @return the shrinkage multiplier applied to base-model predictions */
    public double getShrinkage() {
        return this.m_Shrinkage;
    }

    /** @param d the shrinkage multiplier applied to base-model predictions */
    public void setShrinkage(double d) {
        this.m_Shrinkage = d;
    }

    /** @return tip text for the likelihood-threshold property */
    public String likelihoodThresholdTipText() {
        return LocalString.get("Threshold on improvement in likelihood.");
    }

    /** @return the early-stopping threshold on likelihood change */
    public double getLikelihoodThreshold() {
        return this.m_Precision;
    }

    /** @param d the early-stopping threshold on likelihood change */
    public void setLikelihoodThreshold(double d) {
        this.m_Precision = d;
    }

    /** @return tip text for the number-of-runs property */
    public String numRunsTipText() {
        return LocalString.get("Number of runs for internal cross-validation.");
    }

    /** @return the number of internal cross-validation runs */
    public int getNumRuns() {
        return this.m_NumRuns;
    }

    /** @param n the number of internal cross-validation runs */
    public void setNumRuns(int n) {
        this.m_NumRuns = n;
    }

    /** @return tip text for the number-of-folds property */
    public String numFoldsTipText() {
        return LocalString.get("Number of folds for internal cross-validation (default 0 ") + LocalString.get("means no cross-validation is performed).");
    }

    /** @return the number of internal cross-validation folds */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /** @param n the number of internal cross-validation folds (below 2 disables CV) */
    public void setNumFolds(int n) {
        this.m_NumFolds = n;
    }

    /** @return tip text for the resampling property */
    public String useResamplingTipText() {
        return LocalString.get("Whether resampling is used instead of reweighting.");
    }

    /** @param bl whether to resample instead of reweighting */
    public void setUseResampling(boolean bl) {
        this.m_UseResampling = bl;
    }

    /** @return whether resampling is used instead of reweighting */
    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    /** @return tip text for the weight-threshold property */
    public String weightThresholdTipText() {
        return LocalString.get("Weight threshold for weight pruning (reduce to 90 ") + LocalString.get("for speeding up learning process).");
    }

    /** @param n percentage of weight mass used per base model */
    public void setWeightThreshold(int n) {
        this.m_WeightThreshold = n;
    }

    /** @return percentage of weight mass used per base model */
    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    /**
     * Builds the boosted model.
     *
     * <p>If the base learner does not handle weights, resampling is switched on
     * (this permanently changes the option). If m_NumFolds &gt; 1, the number of
     * iterations is chosen by internal cross-validation first. Note that the
     * input's class index is used to replace the class with the pseudo class on
     * an internal copy.
     *
     * @param instances the training data (nominal class required)
     * @throws Exception if the data or configuration is unsuitable, or a base model fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.m_RandomInstance = new Random(this.m_Seed);
        int classIndex = instances.classIndex();
        if (instances.classAttribute().isNumeric()) {
            throw new UnsupportedClassTypeException(LocalString.get("LogitBoost can't handle a numeric class!"));
        }
        if (this.m_Classifier == null) {
            throw new Exception(LocalString.get("A base classifier has not been specified!"));
        }
        if (!(this.m_Classifier instanceof WeightedInstancesHandler) && !this.m_UseResampling) {
            this.m_UseResampling = true;
        }
        if (instances.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException(LocalString.get("Cannot handle string attributes!"));
        }
        if (this.m_Debug) {
            System.err.println(LocalString.get("Creating copy of the training data"));
        }
        this.m_NumClasses = instances.numClasses();
        this.m_ClassAttribute = instances.classAttribute();
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        if (this.m_Debug) {
            System.err.println(LocalString.get("Creating base classifiers"));
        }
        this.m_Classifiers = new Classifier[this.m_NumClasses][];
        for (int j = 0; j < this.m_NumClasses; ++j) {
            this.m_Classifiers[j] = Classifier.makeCopies(this.m_Classifier, this.getNumIterations());
        }

        int numIterations = this.getNumIterations();
        if (this.m_NumFolds > 1) {
            // Randomizes/stratifies `instances` in place, as the final build expects.
            numIterations = this.bestNumIterationsByCV(instances, classIndex);
        }

        // Final model on all the data. Targets must be read before the class is replaced.
        int numInstances = instances.numInstances();
        double[][] trainFs = new double[numInstances][this.m_NumClasses];
        double[][] trainYs = this.classIndicators(instances);
        replaceClassWithPseudoClass(instances, classIndex);
        this.m_NumericClassData = new Instances(instances, 0);
        double[][] trainProbs = this.initialProbs(numInstances);
        double logLikelihood = this.logLikelihood(trainYs, trainProbs);
        this.m_NumGenerated = 0;
        if (this.m_Debug) {
            System.err.println(LocalString.get("Avg. log-likelihood: ") + logLikelihood);
        }
        double sumOfWeights = instances.sumOfWeights();
        for (int iteration = 0; iteration < numIterations; ++iteration) {
            double previousLogLikelihood = logLikelihood;
            this.performIteration(trainYs, trainFs, trainProbs, instances, sumOfWeights);
            logLikelihood = this.logLikelihood(trainYs, trainProbs);
            if (this.m_Debug) {
                System.err.println(LocalString.get("Avg. log-likelihood: ") + logLikelihood);
            }
            if (Math.abs(previousLogLikelihood - logLikelihood) < this.m_Precision) {
                return;
            }
        }
    }

    /**
     * Runs m_NumRuns x m_NumFolds cross-validation and returns the number of
     * iterations (1-based) with the highest total number of correct test predictions.
     * Randomizes and stratifies {@code data} in place.
     */
    private int bestNumIterationsByCV(Instances data, int classIndex) throws Exception {
        if (this.m_Debug) {
            System.err.println(LocalString.get("Processing first fold."));
        }
        int maxIterations = this.getNumIterations();
        // results[k] accumulates correct predictions after k + 1 iterations.
        double[] results = new double[maxIterations];
        for (int run = 0; run < this.m_NumRuns; ++run) {
            data.randomize(this.m_RandomInstance);
            data.stratify(this.m_NumFolds);
            for (int fold = 0; fold < this.m_NumFolds; ++fold) {
                Instances train = data.trainCV(this.m_NumFolds, fold, this.m_RandomInstance);
                Instances test = data.testCV(this.m_NumFolds, fold);
                Instances numericTrain = new Instances(train);
                replaceClassWithPseudoClass(numericTrain, classIndex);
                this.m_NumericClassData = new Instances(numericTrain, 0);
                int numTrain = train.numInstances();
                double[][] trainFs = new double[numTrain][this.m_NumClasses];
                double[][] trainYs = this.classIndicators(train);
                double[][] trainProbs = this.initialProbs(numTrain);
                this.m_NumGenerated = 0;
                double sumOfWeights = train.sumOfWeights();
                for (int k = 0; k < maxIterations; ++k) {
                    this.performIteration(trainYs, trainFs, trainProbs, numericTrain, sumOfWeights);
                    Evaluation evaluation = new Evaluation(train);
                    evaluation.evaluateModel(this, test);
                    results[k] += evaluation.correct();
                }
            }
        }
        int best = maxIterations;
        double bestResult = -Double.MAX_VALUE;
        for (int k = 0; k < maxIterations; ++k) {
            if (results[k] > bestResult) {
                bestResult = results[k];
                // Index k corresponds to k + 1 performed iterations.
                best = k + 1;
            }
        }
        if (this.m_Debug) {
            System.err.println(LocalString.get("Best result for ") + best + LocalString.get(" iterations: ") + bestResult);
        }
        return best;
    }

    /** Builds the per-class target matrix: 1 - m_Offset for the true class, m_Offset / K otherwise. */
    private double[][] classIndicators(Instances data) {
        int numInstances = data.numInstances();
        double[][] targets = new double[numInstances][this.m_NumClasses];
        for (int i = 0; i < numInstances; ++i) {
            double classValue = data.instance(i).classValue();
            for (int j = 0; j < this.m_NumClasses; ++j) {
                targets[i][j] = classValue == (double) j
                    ? 1.0 - this.m_Offset
                    : this.m_Offset / (double) this.m_NumClasses;
            }
        }
        return targets;
    }

    /** Replaces the attribute at classIndex with a numeric 'pseudo class' attribute, in place. */
    private static void replaceClassWithPseudoClass(Instances data, int classIndex) {
        data.setClassIndex(-1);
        data.deleteAttributeAt(classIndex);
        data.insertAttributeAt(new Attribute(LocalString.get("'pseudo class'")), classIndex);
        data.setClassIndex(classIndex);
    }

    /**
     * Returns an n x m_NumClasses matrix filled with the uniform probability 1 / m_NumClasses.
     */
    private double[][] initialProbs(int n) {
        double[][] probs = new double[n][this.m_NumClasses];
        double uniform = 1.0 / (double) this.m_NumClasses;
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < this.m_NumClasses; ++j) {
                probs[i][j] = uniform;
            }
        }
        return probs;
    }

    /**
     * Average negative log-probability assigned to each instance's true class
     * (the entry whose target equals 1 - m_Offset). Smaller is better.
     */
    private double logLikelihood(double[][] trainYs, double[][] probs) {
        double sum = 0.0;
        for (int i = 0; i < trainYs.length; ++i) {
            for (int j = 0; j < this.m_NumClasses; ++j) {
                if (trainYs[i][j] == 1.0 - this.m_Offset) {
                    sum -= Math.log(probs[i][j]);
                }
            }
        }
        return sum / (double) trainYs.length;
    }

    /**
     * Performs one LogitBoost round.
     *
     * <p>For each class it sets the pseudo class of a copy of {@code data} to the
     * clipped working response z, weights each instance by (y - p) / z, and
     * rescales the weights so they sum to {@code origSumOfWeights}. It then
     * trains the base model (after optional weight pruning or resampling). The
     * additive model {@code trainFs} is updated with centred, shrunken
     * predictions, and {@code trainProbs} is recomputed in place.
     *
     * @param trainYs class-indicator targets, [instance][class]
     * @param trainFs additive model values, updated in place
     * @param trainProbs class probabilities, rows replaced in place
     * @param data training data whose class is the numeric pseudo class
     * @param origSumOfWeights total weight to rescale each round's weights to
     */
    private void performIteration(double[][] trainYs, double[][] trainFs, double[][] trainProbs, Instances data, double origSumOfWeights) throws Exception {
        if (this.m_Debug) {
            System.err.println(LocalString.get("Training classifier ") + (this.m_NumGenerated + 1));
        }
        for (int j = 0; j < this.m_NumClasses; ++j) {
            if (this.m_Debug) {
                System.err.println(LocalString.get("\t...for class ") + (j + 1) + " (" + this.m_ClassAttribute.name() + "=" + this.m_ClassAttribute.value(j) + ")");
            }
            Instances boostData = new Instances(data);
            for (int i = 0; i < trainProbs.length; ++i) {
                double p = trainProbs[i][j];
                double y = trainYs[i][j];
                double z;
                if (y == 1.0 - this.m_Offset) {
                    z = 1.0 / p;
                    if (z > Z_MAX) {
                        z = Z_MAX;
                    }
                } else {
                    z = -1.0 / (1.0 - p);
                    if (z < -Z_MAX) {
                        z = -Z_MAX;
                    }
                }
                double w = (y - p) / z;
                Instance current = boostData.instance(i);
                current.setValue(boostData.classIndex(), z);
                current.setWeight(current.weight() * w);
            }
            double scale = origSumOfWeights / boostData.sumOfWeights();
            for (int i = 0; i < trainProbs.length; ++i) {
                Instance current = boostData.instance(i);
                current.setWeight(current.weight() * scale);
            }
            Instances trainData = boostData;
            if (this.m_WeightThreshold < 100) {
                trainData = this.selectWeightQuantile(boostData, (double) this.m_WeightThreshold / 100.0);
            } else if (this.m_UseResampling) {
                double[] weights = new double[boostData.numInstances()];
                for (int i = 0; i < weights.length; ++i) {
                    weights[i] = boostData.instance(i).weight();
                }
                trainData = boostData.resampleWithWeights(this.m_RandomInstance, weights);
            }
            this.m_Classifiers[j][this.m_NumGenerated].buildClassifier(trainData);
        }

        double[] pred = new double[this.m_NumClasses];
        for (int i = 0; i < trainFs.length; ++i) {
            double predMean = 0.0;
            for (int j = 0; j < this.m_NumClasses; ++j) {
                pred[j] = this.m_Shrinkage * this.m_Classifiers[j][this.m_NumGenerated].classifyInstance(data.instance(i));
                predMean += pred[j];
            }
            predMean /= (double) this.m_NumClasses;
            for (int j = 0; j < this.m_NumClasses; ++j) {
                trainFs[i][j] += (pred[j] - predMean) * (double) (this.m_NumClasses - 1) / (double) this.m_NumClasses;
            }
        }
        ++this.m_NumGenerated;
        for (int i = 0; i < trainYs.length; ++i) {
            trainProbs[i] = this.probs(trainFs[i]);
        }
    }

    /**
     * Returns the base models built so far, indexed as [class][iteration].
     *
     * @return a new array holding references to the m_NumGenerated built models per class
     */
    public Classifier[][] classifiers() {
        Classifier[][] built = new Classifier[this.m_NumClasses][this.m_NumGenerated];
        for (int j = 0; j < this.m_NumClasses; ++j) {
            for (int k = 0; k < this.m_NumGenerated; ++k) {
                built[j][k] = this.m_Classifiers[j][k];
            }
        }
        return built;
    }

    /**
     * Converts additive model values into probabilities via softmax. The maximum
     * is subtracted first so that Math.exp cannot overflow.
     */
    private double[] probs(double[] fs) {
        double max = -Double.MAX_VALUE;
        for (int i = 0; i < fs.length; ++i) {
            if (fs[i] > max) {
                max = fs[i];
            }
        }
        double sum = 0.0;
        double[] result = new double[fs.length];
        for (int i = 0; i < fs.length; ++i) {
            result[i] = Math.exp(fs[i] - max);
            sum += result[i];
        }
        Utils.normalize(result, sum);
        return result;
    }

    /**
     * Computes class probabilities for an instance. The additive model is
     * accumulated exactly as during training, including the shrinkage factor, so
     * predictions agree with the probabilities used while boosting.
     *
     * @param instance the instance to classify (copied, not modified)
     * @return the class probability distribution
     * @throws Exception if a base model fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        instance = (Instance) instance.copy();
        instance.setDataset(this.m_NumericClassData);
        double[] pred = new double[this.m_NumClasses];
        double[] fs = new double[this.m_NumClasses];
        for (int k = 0; k < this.m_NumGenerated; ++k) {
            double predMean = 0.0;
            for (int j = 0; j < this.m_NumClasses; ++j) {
                // Apply shrinkage here as well; performIteration does so when fitting.
                pred[j] = this.m_Shrinkage * this.m_Classifiers[j][k].classifyInstance(instance);
                predMean += pred[j];
            }
            predMean /= (double) this.m_NumClasses;
            for (int j = 0; j < this.m_NumClasses; ++j) {
                fs[j] += (pred[j] - predMean) * (double) (this.m_NumClasses - 1) / (double) this.m_NumClasses;
            }
        }
        return this.probs(fs);
    }

    /**
     * Generates Java source for the built model, including the source of every
     * base model actually used (the first m_NumGenerated per class).
     *
     * @param string name of the generated class
     * @return the generated source
     * @throws Exception if no model has been built or the base learner is not Sourcable
     */
    public String toSource(String string) throws Exception {
        if (this.m_NumGenerated == 0) {
            throw new Exception(LocalString.get("No model built yet"));
        }
        if (!(this.m_Classifiers[0][0] instanceof Sourcable)) {
            throw new Exception(LocalString.get("Base learner ") + this.m_Classifier.getClass().getName() + LocalString.get(" is not Sourcable"));
        }
        StringBuffer text = new StringBuffer(LocalString.get("class "));
        text.append(string).append(" {\n\n");
        // RtoP: centred softmax. The numerator is centred like the denominator.
        // It is a plain literal because generated code must not be localized.
        text.append(LocalString.get("  private static double RtoP(double []R, int j) {\n")
            + LocalString.get("    double Rcenter = 0;\n")
            + LocalString.get("    for (int i = 0; i < R.length; i++) {\n")
            + LocalString.get("      Rcenter += R[i];\n")
            + "    }\n"
            + LocalString.get("    Rcenter /= R.length;\n")
            + LocalString.get("    double Rsum = 0;\n")
            + LocalString.get("    for (int i = 0; i < R.length; i++) {\n")
            + LocalString.get("      Rsum += Math.exp(R[i] - Rcenter);\n")
            + "    }\n"
            + "    return Math.exp(R[j] - Rcenter) / Rsum;\n"
            + "  }\n\n");
        text.append(LocalString.get("  public static double classify(Object [] i) {\n")
            + LocalString.get("    double [] d = distribution(i);\n")
            + LocalString.get("    double maxV = d[0];\n")
            + LocalString.get("    int maxI = 0;\n")
            + LocalString.get("    for (int j = 1; j < ") + this.m_NumClasses + "; j++) {\n"
            + LocalString.get("      if (d[j] > maxV) { maxV = d[j]; maxI = j; }\n")
            + LocalString.get("    }\n    return (double) maxI;\n  }\n\n"));
        text.append(LocalString.get("  public static double [] distribution(Object [] i) {\n"));
        text.append(LocalString.get("    double [] Fs = new double [") + this.m_NumClasses + "];\n");
        text.append(LocalString.get("    double [] Fi = new double [") + this.m_NumClasses + "];\n");
        text.append(LocalString.get("    double Fsum;\n"));
        // Mirror distributionForInstance. The factor is omitted when shrinkage is 1.
        String shrinkageFactor = this.m_Shrinkage == 1.0 ? "" : this.m_Shrinkage + " * ";
        for (int k = 0; k < this.m_NumGenerated; ++k) {
            text.append(LocalString.get("    Fsum = 0;\n"));
            for (int j = 0; j < this.m_NumClasses; ++j) {
                text.append("    Fi[" + j + "] = " + shrinkageFactor + string + '_' + j + '_' + k + LocalString.get(".classify(i); Fsum += Fi[") + j + "];\n");
            }
            text.append(LocalString.get("    Fsum /= ") + this.m_NumClasses + ";\n");
            text.append(LocalString.get("    for (int j = 0; j < ") + this.m_NumClasses + "; j++) {");
            text.append(LocalString.get(" Fs[j] += (Fi[j] - Fsum) * ") + (this.m_NumClasses - 1) + " / " + this.m_NumClasses + "; }\n");
        }
        text.append(LocalString.get("    double [] dist = new double [") + this.m_NumClasses + "];\n"
            + LocalString.get("    for (int j = 0; j < ") + this.m_NumClasses + "; j++) {\n"
            + LocalString.get("      dist[j] = RtoP(Fs, j);\n")
            + LocalString.get("    }\n    return dist;\n"));
        text.append("  }\n}\n");
        // Only the first m_NumGenerated models per class were built and are referenced above.
        for (int j = 0; j < this.m_Classifiers.length; ++j) {
            for (int k = 0; k < this.m_NumGenerated; ++k) {
                text.append(((Sourcable) this.m_Classifiers[j][k]).toSource(string + '_' + j + '_' + k));
            }
        }
        return text.toString();
    }

    /**
     * Returns a textual description of every built base model, grouped by iteration.
     *
     * @return the model description
     */
    public String toString() {
        StringBuffer text = new StringBuffer();
        if (this.m_NumGenerated == 0) {
            text.append(LocalString.get("LogitBoost: No model built yet."));
        } else {
            text.append(LocalString.get("LogitBoost: Base classifiers and their weights: \n"));
            for (int k = 0; k < this.m_NumGenerated; ++k) {
                text.append(LocalString.get("\nIteration ") + (k + 1));
                for (int j = 0; j < this.m_NumClasses; ++j) {
                    text.append(LocalString.get("\n\tClass ") + (j + 1) + " (" + this.m_ClassAttribute.name() + "=" + this.m_ClassAttribute.value(j) + ")\n\n" + this.m_Classifiers[j][k].toString() + "\n");
                }
            }
            text.append(LocalString.get("Number of performed iterations: ") + this.m_NumGenerated + "\n");
        }
        return text.toString();
    }

    /**
     * Command-line entry point: evaluates LogitBoost with the given options.
     *
     * @param stringArray command-line options
     */
    public static void main(String[] stringArray) {
        try {
            System.out.println(Evaluation.evaluateModel(new LogitBoost(), stringArray));
        }
        catch (Exception exception) {
            exception.printStackTrace();
            System.err.println(exception.getMessage());
        }
    }
}

