/*
 * Decompiled with CFR 0.152.
 */
package weka.experiment;

import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Vector;
import weka.LocalString;
import weka.core.AdditionalMeasureProducer;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.experiment.CSVResultListener;
import weka.experiment.CrossValidationResultProducer;
import weka.experiment.DatabaseUtils;
import weka.experiment.ResultListener;
import weka.experiment.ResultProducer;
import weka.experiment.Stats;

/**
 * A ResultProducer that wraps another ResultProducer, collects the results it
 * produces for a run, and forwards their average to this producer's own
 * ResultListener.
 * <p>
 * The collected results may differ only in the key field named by
 * {@link #getKeyFieldName()} (by default
 * {@code CrossValidationResultProducer.FOLD_FIELD_NAME}). That field is removed
 * from the key handed to the listener, and a count column
 * ({@code "Num_" + keyFieldName}) is prepended to the results. Result fields
 * whose type exemplar is a {@code Double} are averaged (optionally followed by a
 * standard-deviation column); all other fields are copied from the first
 * matching result.
 */
public class AveragingResultProducer
implements ResultListener,
ResultProducer,
OptionHandler,
AdditionalMeasureProducer {
    /** The dataset handed to the wrapped result producer. */
    protected Instances m_Instances;
    /** Receives the averaged results. */
    protected ResultListener m_ResultListener = new CSVResultListener();
    /** The producer whose results are averaged; this object registers itself as its listener. */
    protected ResultProducer m_ResultProducer = new CrossValidationResultProducer();
    /** Additional measure names, forwarded to the wrapped producer. */
    protected String[] m_AdditionalMeasures = null;
    /** Number of results that must match each averaging template. */
    protected int m_ExpectedResultsPerAverage = 10;
    /** Whether a "Dev_" column is emitted after each numeric "Avg_" column. */
    protected boolean m_CalculateStdDevs;
    /** Name of the first result column, which holds the number of averaged results. */
    protected String m_CountFieldName = "Num_" + CrossValidationResultProducer.FOLD_FIELD_NAME;
    /** Name of the key field that varies across the results being averaged. */
    protected String m_KeyFieldName = CrossValidationResultProducer.FOLD_FIELD_NAME;
    /** Position of m_KeyFieldName among the wrapped producer's key names, or -1 if not found. */
    protected int m_KeyIndex = -1;
    /** Keys received from the wrapped producer during the current collection pass. */
    protected FastVector m_Keys = new FastVector();
    /** Results received from the wrapped producer, parallel to m_Keys. */
    protected FastVector m_Results = new FastVector();

    /**
     * Returns a short description of what this result producer does.
     *
     * @return the description text
     */
    public String globalInfo() {
        return LocalString.get("Takes the results from a ResultProducer ") + LocalString.get("and submits the average to the result listener. Normally used with ") + LocalString.get("a CrossValidationResultProducer to perform n x m fold cross ") + "validation.";
    }

    /**
     * Locates the averaged key field among the wrapped producer's key names and
     * stores its position in m_KeyIndex.
     *
     * @return the index of the key field, or -1 if there is no producer, the
     *         field is absent, or the key names could not be obtained
     */
    protected int findKeyIndex() {
        this.m_KeyIndex = -1;
        if (this.m_ResultProducer == null) {
            return this.m_KeyIndex;
        }
        try {
            String[] keyNames = this.m_ResultProducer.getKeyNames();
            for (int i = 0; i < keyNames.length; ++i) {
                if (keyNames[i].equals(this.m_KeyFieldName)) {
                    this.m_KeyIndex = i;
                    break;
                }
            }
        }
        catch (Exception ignored) {
            // Deliberate best effort: this is called from setters, when the wrapped
            // producer may not yet be able to report its keys. m_KeyIndex stays -1
            // and preProcess()/getKeyNames()/getKeyTypes() report the problem.
        }
        return this.m_KeyIndex;
    }

    /**
     * This producer imposes no column constraints.
     *
     * @param resultProducer the producer asking for constraints (unused)
     * @return always null
     * @throws Exception never thrown by this implementation
     */
    public String[] determineColumnConstraints(ResultProducer resultProducer) throws Exception {
        return null;
    }

    /**
     * Asks the wrapped producer for the keys of run {@code n} (via doRunKeys) and
     * builds the template key used to select the results to average.
     *
     * @param n the run number
     * @return a copy of the first received key with the averaged field set to null
     * @throws Exception if no Instances are set, the key field has not been
     *         located, or the received keys fail the consistency checks
     */
    protected Object[] determineTemplate(int n) throws Exception {
        if (this.m_Instances == null) {
            throw new Exception(LocalString.get("No Instances set"));
        }
        // Fail before running the wrapped producer; the template cannot be built
        // without knowing which key field to blank out.
        this.checkKeyIndex();
        this.m_ResultProducer.setInstances(this.m_Instances);
        this.m_Keys.removeAllElements();
        this.m_Results.removeAllElements();
        this.m_ResultProducer.doRunKeys(n);
        return this.templateFromCollectedKeys();
    }

    /**
     * Builds the averaging template from the keys currently in m_Keys. It checks
     * that the keys differ only in the averaged field, copies the first key with
     * that field set to null, and checks the count of distinct matching keys.
     */
    private Object[] templateFromCollectedKeys() throws Exception {
        this.checkForMultipleDifferences();
        Object[] template = (Object[])((Object[])this.m_Keys.elementAt(0)).clone();
        template[this.m_KeyIndex] = null;
        this.checkForDuplicateKeys(template);
        return template;
    }

    /**
     * Returns a copy of {@code key} with the element at m_KeyIndex removed. The
     * copy is an Object[] so that key values which are not Strings are accepted.
     */
    private Object[] removeKeyField(Object[] key) {
        Object[] reduced = new Object[key.length - 1];
        System.arraycopy(key, 0, reduced, 0, this.m_KeyIndex);
        System.arraycopy(key, this.m_KeyIndex + 1, reduced, this.m_KeyIndex, key.length - this.m_KeyIndex - 1);
        return reduced;
    }

    /** Throws the standard "No key field" exception when m_KeyIndex is -1. */
    private void checkKeyIndex() throws Exception {
        if (this.m_KeyIndex == -1) {
            throw new Exception(LocalString.get("No key field called ") + this.m_KeyFieldName + LocalString.get(" produced by ") + this.producerName());
        }
    }

    /** Class name of the wrapped producer, or "null" when none is set; used in messages. */
    private String producerName() {
        return this.m_ResultProducer == null ? "null" : this.m_ResultProducer.getClass().getName();
    }

    /**
     * Sends the listener the averaged key (without the averaged field) for run
     * {@code n}, with a null result.
     *
     * @param n the run number
     * @throws Exception if the template cannot be determined or the listener fails
     */
    public void doRunKeys(int n) throws Exception {
        Object[] template = this.determineTemplate(n);
        this.m_ResultListener.acceptResult(this, this.removeKeyField(template), null);
    }

    /**
     * Runs the wrapped producer for run {@code n} and forwards the averaged
     * result. The run is skipped entirely when the listener reports that the
     * averaged key is not required.
     *
     * @param n the run number
     * @throws Exception if the template cannot be determined, the collected
     *         results are inconsistent, or the producer/listener fail
     */
    public void doRun(int n) throws Exception {
        Object[] template = this.determineTemplate(n);
        if (!this.m_ResultListener.isResultRequired(this, this.removeKeyField(template))) {
            return;
        }
        this.m_Keys.removeAllElements();
        this.m_Results.removeAllElements();
        this.m_ResultProducer.doRun(n);
        this.doAverageResult(this.templateFromCollectedKeys());
    }

    /**
     * Tests whether a key matches a template. Null template entries act as
     * wildcards; every other entry must be equal.
     *
     * @param template the template key (null entries match anything)
     * @param test the key to test
     * @return true if both arrays have the same length and all non-null
     *         template entries equal the corresponding key entries
     */
    protected boolean matchesTemplate(Object[] template, Object[] test) {
        if (template.length != test.length) {
            return false;
        }
        for (int i = 0; i < test.length; ++i) {
            if (template[i] != null && !template[i].equals(test[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Averages all collected results whose keys match {@code template} and
     * passes the averaged result to the listener (if the listener requires it).
     * A numeric field is reported as null if any matching result has a null
     * value for it.
     *
     * @param template the template key, with the averaged field set to null
     * @throws Exception if the number of matching results differs from
     *         m_ExpectedResultsPerAverage, or the producer/listener fail
     */
    protected void doAverageResult(Object[] template) throws Exception {
        Object[] newKey = this.removeKeyField(template);
        if (!this.m_ResultListener.isResultRequired(this, newKey)) {
            return;
        }
        Object[] producerTypes = this.m_ResultProducer.getResultTypes();
        Stats[] stats = new Stats[producerTypes.length];
        for (int i = 0; i < stats.length; ++i) {
            stats[i] = new Stats();
        }
        Object[] averaged = this.getResultTypes();
        int numMatched = 0;
        Object[] firstMatch = null;
        for (int i = 0; i < this.m_Keys.size(); ++i) {
            Object[] key = (Object[])this.m_Keys.elementAt(i);
            if (!this.matchesTemplate(template, key)) {
                continue;
            }
            Object[] result = (Object[])this.m_Results.elementAt(i);
            if (firstMatch == null) {
                firstMatch = result;
            }
            ++numMatched;
            for (int j = 0; j < producerTypes.length; ++j) {
                if (!(producerTypes[j] instanceof Double)) {
                    continue;
                }
                // A single missing value makes the whole average for this field missing.
                if (result[j] == null) {
                    stats[j] = null;
                }
                if (stats[j] != null) {
                    stats[j].add(((Double)result[j]).doubleValue());
                }
            }
        }
        this.checkMatchCount(template, numMatched);
        averaged[0] = Double.valueOf(numMatched);
        int out = 1;
        for (int i = 0; i < producerTypes.length; ++i) {
            if (!(producerTypes[i] instanceof Double)) {
                // Non-numeric fields are taken from the first matching result.
                averaged[out++] = firstMatch == null ? null : firstMatch[i];
                continue;
            }
            if (stats[i] != null) {
                stats[i].calculateDerived();
            }
            averaged[out++] = stats[i] == null ? null : Double.valueOf(stats[i].mean);
            if (this.getCalculateStdDevs()) {
                averaged[out++] = stats[i] == null ? null : Double.valueOf(stats[i].stdDev);
            }
        }
        this.m_ResultListener.acceptResult(this, newKey, averaged);
    }

    /** Throws if {@code numMatched} differs from the expected number of results per average. */
    private void checkMatchCount(Object[] template, int numMatched) throws Exception {
        if (numMatched != this.m_ExpectedResultsPerAverage) {
            throw new Exception(LocalString.get("Expected ") + this.m_ExpectedResultsPerAverage + LocalString.get(" results matching key \"") + DatabaseUtils.arrayToString(template) + LocalString.get("\" but got ") + numMatched);
        }
    }

    /**
     * Checks that no two collected keys matching the template share the same
     * value in the averaged field, and that the number of matches is as expected.
     *
     * @param template the template key, with the averaged field set to null
     * @throws Exception on a duplicate key value or an unexpected match count
     */
    protected void checkForDuplicateKeys(Object[] template) throws Exception {
        Hashtable<Object, Object> seen = new Hashtable<Object, Object>();
        int numMatched = 0;
        for (int i = 0; i < this.m_Keys.size(); ++i) {
            Object[] key = (Object[])this.m_Keys.elementAt(i);
            if (!this.matchesTemplate(template, key)) {
                continue;
            }
            Object keyValue = key[this.m_KeyIndex];
            if (seen.containsKey(keyValue)) {
                throw new Exception(LocalString.get("Duplicate result received:") + DatabaseUtils.arrayToString(key));
            }
            ++numMatched;
            seen.put(keyValue, keyValue);
        }
        this.checkMatchCount(template, numMatched);
    }

    /**
     * Checks that every collected key agrees with the first one on all fields
     * except the averaged field. Averaging over several varying fields is not
     * supported.
     *
     * @throws Exception if no keys were collected, or keys differ on a field
     *         other than the averaged one
     */
    protected void checkForMultipleDifferences() throws Exception {
        if (this.m_Keys.size() == 0) {
            throw new Exception("No results received from " + this.producerName());
        }
        Object[] first = (Object[])this.m_Keys.elementAt(0);
        for (int k = 1; k < this.m_Keys.size(); ++k) {
            Object[] other = (Object[])this.m_Keys.elementAt(k);
            for (int i = 0; i < first.length; ++i) {
                if (i == this.m_KeyIndex) {
                    continue;
                }
                boolean same = first[i] == null ? other[i] == null : first[i].equals(other[i]);
                if (!same) {
                    throw new Exception(LocalString.get("Keys differ on fields other than \"") + this.m_KeyFieldName + LocalString.get("\" -- time to implement multiple averaging"));
                }
            }
        }
    }

    /**
     * Listener-side preprocessing: forwards to this producer's listener.
     *
     * @param resultProducer the producer about to send results (unused)
     * @throws Exception if no ResultListener is set
     */
    public void preProcess(ResultProducer resultProducer) throws Exception {
        if (this.m_ResultListener == null) {
            throw new Exception(LocalString.get("No ResultListener set"));
        }
        this.m_ResultListener.preProcess(this);
    }

    /**
     * Producer-side preprocessing: registers this object as the wrapped
     * producer's listener, locates the key field, and preprocesses the wrapped
     * producer.
     *
     * @throws Exception if no ResultProducer is set or the key field is missing
     */
    public void preProcess() throws Exception {
        if (this.m_ResultProducer == null) {
            throw new Exception(LocalString.get("No ResultProducer set"));
        }
        this.m_ResultProducer.setResultListener(this);
        this.findKeyIndex();
        this.checkKeyIndex();
        this.m_ResultProducer.preProcess();
    }

    /**
     * Listener-side postprocessing: forwards to this producer's listener.
     *
     * @param resultProducer the producer that sent results (unused)
     * @throws Exception if the listener fails
     */
    public void postProcess(ResultProducer resultProducer) throws Exception {
        this.m_ResultListener.postProcess(this);
    }

    /**
     * Producer-side postprocessing: forwards to the wrapped producer.
     *
     * @throws Exception if the wrapped producer fails
     */
    public void postProcess() throws Exception {
        this.m_ResultProducer.postProcess();
    }

    /**
     * Stores a key/result pair sent by the wrapped producer for later averaging.
     *
     * @param resultProducer the sender; must be the wrapped producer
     * @param key the result key
     * @param result the result values (may be null when only keys are requested)
     * @throws Exception declared for callers; not thrown here
     * @throws Error if the sender is not the wrapped producer
     */
    public void acceptResult(ResultProducer resultProducer, Object[] key, Object[] result) throws Exception {
        if (this.m_ResultProducer != resultProducer) {
            throw new Error(LocalString.get("Unrecognized ResultProducer sending results!!"));
        }
        this.m_Keys.addElement(key);
        this.m_Results.addElement(result);
    }

    /**
     * Every result from the wrapped producer is required, since all of them
     * contribute to the average.
     *
     * @param resultProducer the sender; must be the wrapped producer
     * @param key the result key (unused)
     * @return always true
     * @throws Exception declared for callers; not thrown here
     * @throws Error if the sender is not the wrapped producer
     */
    public boolean isResultRequired(ResultProducer resultProducer, Object[] key) throws Exception {
        if (this.m_ResultProducer != resultProducer) {
            throw new Error(LocalString.get("Unrecognized ResultProducer sending results!!"));
        }
        return true;
    }

    /**
     * Returns the wrapped producer's key names without the averaged field.
     *
     * @return the key names
     * @throws Exception if the key field has not been located
     */
    public String[] getKeyNames() throws Exception {
        this.checkKeyIndex();
        String[] keyNames = this.m_ResultProducer.getKeyNames();
        String[] reduced = new String[keyNames.length - 1];
        System.arraycopy(keyNames, 0, reduced, 0, this.m_KeyIndex);
        System.arraycopy(keyNames, this.m_KeyIndex + 1, reduced, this.m_KeyIndex, keyNames.length - this.m_KeyIndex - 1);
        return reduced;
    }

    /**
     * Returns the wrapped producer's key type exemplars without the averaged field.
     *
     * @return the key types
     * @throws Exception if the key field has not been located
     */
    public Object[] getKeyTypes() throws Exception {
        this.checkKeyIndex();
        return this.removeKeyField(this.m_ResultProducer.getKeyTypes());
    }

    /** Counts the entries of {@code types} that are Double exemplars. */
    private static int countNumeric(Object[] types) {
        int numeric = 0;
        for (int i = 0; i < types.length; ++i) {
            if (types[i] instanceof Double) {
                ++numeric;
            }
        }
        return numeric;
    }

    /**
     * Returns the result column names. The count column comes first. With
     * standard deviations enabled, each name gets an "Avg_" prefix and numeric
     * fields are followed by a "Dev_" column.
     *
     * @return the result names
     * @throws Exception if the wrapped producer fails
     */
    public String[] getResultNames() throws Exception {
        String[] names = this.m_ResultProducer.getResultNames();
        if (!this.getCalculateStdDevs()) {
            String[] withCount = new String[names.length + 1];
            withCount[0] = this.m_CountFieldName;
            System.arraycopy(names, 0, withCount, 1, names.length);
            return withCount;
        }
        Object[] types = this.m_ResultProducer.getResultTypes();
        String[] withStats = new String[names.length + 1 + countNumeric(types)];
        withStats[0] = this.m_CountFieldName;
        int out = 1;
        for (int i = 0; i < names.length; ++i) {
            withStats[out++] = "Avg_" + names[i];
            if (types[i] instanceof Double) {
                withStats[out++] = "Dev_" + names[i];
            }
        }
        return withStats;
    }

    /**
     * Returns the result type exemplars, laid out in the same way as
     * {@link #getResultNames()}. The count and standard-deviation columns are
     * Doubles.
     *
     * @return the result types
     * @throws Exception if the wrapped producer fails
     */
    public Object[] getResultTypes() throws Exception {
        Object[] types = this.m_ResultProducer.getResultTypes();
        if (!this.getCalculateStdDevs()) {
            Object[] withCount = new Object[types.length + 1];
            withCount[0] = Double.valueOf(0.0);
            System.arraycopy(types, 0, withCount, 1, types.length);
            return withCount;
        }
        Object[] withStats = new Object[types.length + 1 + countNumeric(types)];
        withStats[0] = Double.valueOf(0.0);
        int out = 1;
        for (int i = 0; i < types.length; ++i) {
            withStats[out++] = types[i];
            if (types[i] instanceof Double) {
                withStats[out++] = Double.valueOf(0.0);
            }
        }
        return withStats;
    }

    /**
     * Describes the settings that affect compatibility of results, including the
     * wrapped producer's own compatibility state when a producer is set.
     *
     * @return the trimmed compatibility description
     */
    public String getCompatibilityState() {
        String state = " -X " + this.getExpectedResultsPerAverage() + " ";
        if (this.getCalculateStdDevs()) {
            state = state + "-S ";
        }
        if (this.m_ResultProducer == null) {
            state = state + LocalString.get("<null ResultProducer>");
        } else {
            state = state + "-W " + this.m_ResultProducer.getClass().getName() + " -- " + this.m_ResultProducer.getCompatibilityState();
        }
        return state.trim();
    }

    /**
     * Lists this producer's options, followed by the wrapped producer's options
     * when it is an OptionHandler.
     *
     * @return an enumeration of Option objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option(LocalString.get("\tThe name of the field to average over.\n") + LocalString.get("\t(default \"Fold\")"), "F", 1, LocalString.get("-F <field name>")));
        options.addElement(new Option(LocalString.get("\tThe number of results expected per average.\n") + LocalString.get("\t(default 10)"), "X", 1, LocalString.get("-X <num results>")));
        options.addElement(new Option(LocalString.get("\tCalculate standard deviations.\n") + LocalString.get("\t(default only averages)"), "S", 0, "-S"));
        options.addElement(new Option(LocalString.get("\tThe full class name of a ResultProducer.\n") + LocalString.get("\teg: weka.experiment.CrossValidationResultProducer"), "W", 1, LocalString.get("-W <class name>")));
        if (this.m_ResultProducer instanceof OptionHandler) {
            options.addElement(new Option("", "", 0, LocalString.get("\nOptions specific to result producer ") + this.m_ResultProducer.getClass().getName() + ":"));
            Enumeration producerOptions = ((OptionHandler)this.m_ResultProducer).listOptions();
            while (producerOptions.hasMoreElements()) {
                options.addElement((Option)producerOptions.nextElement());
            }
        }
        return options.elements();
    }

    /**
     * Parses -F, -X, -S and -W, then passes the options after "--" to the wrapped
     * producer when it is an OptionHandler.
     *
     * @param options the option strings
     * @throws Exception if -W is missing or an option is invalid
     */
    public void setOptions(String[] options) throws Exception {
        String keyField = Utils.getOption('F', options);
        this.setKeyFieldName(keyField.length() != 0 ? keyField : CrossValidationResultProducer.FOLD_FIELD_NAME);
        String expected = Utils.getOption('X', options);
        this.setExpectedResultsPerAverage(expected.length() != 0 ? Integer.parseInt(expected) : 10);
        this.setCalculateStdDevs(Utils.getFlag('S', options));
        String producerClass = Utils.getOption('W', options);
        if (producerClass.length() == 0) {
            throw new Exception(LocalString.get("A ResultProducer must be specified with") + LocalString.get(" the -W option."));
        }
        this.setResultProducer((ResultProducer)Utils.forName(ResultProducer.class, producerClass, null));
        if (this.getResultProducer() instanceof OptionHandler) {
            ((OptionHandler)this.getResultProducer()).setOptions(Utils.partitionOptions(options));
        }
    }

    /**
     * Returns the current options, followed by "--" and the wrapped producer's
     * options. Unused trailing slots are filled with empty strings.
     *
     * @return the option strings
     */
    public String[] getOptions() {
        String[] producerOptions = new String[0];
        if (this.m_ResultProducer instanceof OptionHandler) {
            producerOptions = ((OptionHandler)this.m_ResultProducer).getOptions();
        }
        // Room for: -F name -X num -S -W class --
        String[] options = new String[producerOptions.length + 8];
        int pos = 0;
        options[pos++] = "-F";
        options[pos++] = "" + this.getKeyFieldName();
        options[pos++] = "-X";
        options[pos++] = "" + this.getExpectedResultsPerAverage();
        if (this.getCalculateStdDevs()) {
            options[pos++] = "-S";
        }
        if (this.getResultProducer() != null) {
            options[pos++] = "-W";
            options[pos++] = this.getResultProducer().getClass().getName();
        }
        options[pos++] = "--";
        System.arraycopy(producerOptions, 0, options, pos, producerOptions.length);
        pos += producerOptions.length;
        while (pos < options.length) {
            options[pos++] = "";
        }
        return options;
    }

    /**
     * Stores the additional measure names and forwards them to the wrapped
     * producer, if one is set.
     *
     * @param additionalMeasures the measure names
     */
    public void setAdditionalMeasures(String[] additionalMeasures) {
        this.m_AdditionalMeasures = additionalMeasures;
        if (this.m_ResultProducer != null) {
            System.err.println(LocalString.get("AveragingResultProducer: setting additional ") + LocalString.get("measures for ") + "ResultProducer");
            this.m_ResultProducer.setAdditionalMeasures(this.m_AdditionalMeasures);
        }
    }

    /**
     * Enumerates the wrapped producer's additional measures. The result is empty
     * when the producer is not an AdditionalMeasureProducer.
     *
     * @return an enumeration of measure names
     */
    public Enumeration enumerateMeasures() {
        Vector<String> measures = new Vector<String>();
        if (this.m_ResultProducer instanceof AdditionalMeasureProducer) {
            Enumeration producerMeasures = ((AdditionalMeasureProducer)this.m_ResultProducer).enumerateMeasures();
            while (producerMeasures.hasMoreElements()) {
                measures.addElement((String)producerMeasures.nextElement());
            }
        }
        return measures.elements();
    }

    /**
     * Returns the value of a named additional measure from the wrapped producer.
     *
     * @param additionalMeasureName the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the wrapped producer is not an
     *         AdditionalMeasureProducer (or is not set)
     */
    public double getMeasure(String additionalMeasureName) {
        if (this.m_ResultProducer instanceof AdditionalMeasureProducer) {
            return ((AdditionalMeasureProducer)this.m_ResultProducer).getMeasure(additionalMeasureName);
        }
        throw new IllegalArgumentException(LocalString.get("AveragingResultProducer: ") + LocalString.get("Can't return value for : ") + additionalMeasureName + ". " + this.producerName() + " " + LocalString.get("is not an AdditionalMeasureProducer"));
    }

    /**
     * Sets the dataset passed to the wrapped producer on each run.
     *
     * @param instances the dataset
     */
    public void setInstances(Instances instances) {
        this.m_Instances = instances;
    }

    /** @return tip text for the calculateStdDevs property */
    public String calculateStdDevsTipText() {
        return LocalString.get("Record standard deviations for each run.");
    }

    /** @return whether standard-deviation columns are produced */
    public boolean getCalculateStdDevs() {
        return this.m_CalculateStdDevs;
    }

    /** @param calculateStdDevs whether to produce standard-deviation columns */
    public void setCalculateStdDevs(boolean calculateStdDevs) {
        this.m_CalculateStdDevs = calculateStdDevs;
    }

    /** @return tip text for the expectedResultsPerAverage property */
    public String expectedResultsPerAverageTipText() {
        return LocalString.get("Set the expected number of results to average per run. ") + LocalString.get("For example if a CrossValidationResultProducer is being used ") + LocalString.get("(with the number of folds set to 10), then the expected number ") + LocalString.get("of results per run is 10.");
    }

    /** @return the number of results that must match each averaging template */
    public int getExpectedResultsPerAverage() {
        return this.m_ExpectedResultsPerAverage;
    }

    /** @param expectedResultsPerAverage the number of results that must match each template */
    public void setExpectedResultsPerAverage(int expectedResultsPerAverage) {
        this.m_ExpectedResultsPerAverage = expectedResultsPerAverage;
    }

    /** @return tip text for the keyFieldName property */
    public String keyFieldNameTipText() {
        return LocalString.get("Set the field name that will be unique for a run.");
    }

    /** @return the name of the key field averaged over */
    public String getKeyFieldName() {
        return this.m_KeyFieldName;
    }

    /**
     * Sets the key field to average over, updates the count column name, and
     * re-locates the field among the wrapped producer's keys.
     *
     * @param keyFieldName the key field name
     */
    public void setKeyFieldName(String keyFieldName) {
        this.m_KeyFieldName = keyFieldName;
        this.m_CountFieldName = "Num_" + this.m_KeyFieldName;
        this.findKeyIndex();
    }

    /** @param resultListener the listener that receives the averaged results */
    public void setResultListener(ResultListener resultListener) {
        this.m_ResultListener = resultListener;
    }

    /** @return tip text for the resultProducer property */
    public String resultProducerTipText() {
        return LocalString.get("Set the resultProducer for which results are to be averaged.");
    }

    /** @return the wrapped result producer (may be null) */
    public ResultProducer getResultProducer() {
        return this.m_ResultProducer;
    }

    /**
     * Sets the producer whose results are averaged, registers this object as its
     * listener, and re-locates the key field. A null producer is accepted and
     * leaves the key index at -1.
     *
     * @param resultProducer the producer to wrap
     */
    public void setResultProducer(ResultProducer resultProducer) {
        this.m_ResultProducer = resultProducer;
        if (this.m_ResultProducer != null) {
            this.m_ResultProducer.setResultListener(this);
        }
        this.findKeyIndex();
    }

    /**
     * Returns the class label, the compatibility state, and the relation name
     * of the current dataset.
     *
     * @return a description of this producer
     */
    public String toString() {
        String description = LocalString.get("AveragingResultProducer: ") + this.getCompatibilityState();
        if (this.m_Instances == null) {
            return description + LocalString.get(": <null Instances>");
        }
        return description + ": " + Utils.backQuoteChars(this.m_Instances.relationName());
    }
}

