/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.ml;

import com.davidsoergel.dsutils.DSArrayUtils;
import com.google.common.base.Function;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.MapMaker;
import com.google.common.collect.Multiset;
import edu.berkeley.compbio.ml.CrossValidationResults;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Accumulates a multi-class confusion matrix from (real label, predicted label) pairs and
 * derives per-label and aggregate statistics (accuracy, sensitivity, specificity, precision)
 * from it.
 *
 * <p>Rows are keyed by the real label; each row is a multiset of the labels that were predicted
 * for samples with that real label. A {@code null} prediction is what {@link #unknown()} counts
 * as "unknown". Samples whose <em>real</em> label is {@code null} are stored in a separate row
 * outside the matrix; they count towards the total number of examples but are not visited by
 * the methods that iterate over the matrix rows.
 *
 * <p>Query methods never add rows to the matrix: asking about a label that was never passed to
 * {@link #addSample} as a real value is answered from an empty row.
 *
 * @param <L> the label type
 */
public class MultiClassCrossValidationResults<L extends Comparable> extends CrossValidationResults {
    private static final Logger logger = Logger.getLogger(MultiClassCrossValidationResults.class);

    /** Number of samples recorded via {@link #addSample}, including those with a null real label. */
    private int numExamples;

    /** Real label -> multiset of predicted labels. {@code get} on this map creates missing rows. */
    private final Map<L, Multiset<L>> confusionMatrix;

    /** Predictions for samples whose real label was null; addSample never hands a null key to the map. */
    private final Multiset<L> confusionRowNull = HashMultiset.create();

    /**
     * Creates an empty result set whose matrix rows are created on demand.
     */
    public MultiClassCrossValidationResults() {
        this.confusionMatrix = new MapMaker().makeComputingMap(new Function<L, Multiset<L>>(){

            @Override
            public Multiset<L> apply(L key) {
                return HashMultiset.create();
            }
        });
    }

    /**
     * Looks up the row of predictions for a real label without modifying the matrix.
     *
     * @param realLabel the real label, possibly null
     * @return the null row for a null label, the stored row if one exists, otherwise a fresh
     *         empty multiset that is not attached to the matrix
     */
    private Multiset<L> rowFor(L realLabel) {
        if (realLabel == null) {
            return this.confusionRowNull;
        }
        // Check first: a plain get() on the computing map would insert an empty row and thereby
        // change the label set and the denominators of the class-normalized statistics.
        if (!this.confusionMatrix.containsKey(realLabel)) {
            Multiset<L> empty = HashMultiset.create();
            return empty;
        }
        return this.confusionMatrix.get(realLabel);
    }

    /**
     * @return a sorted snapshot of the non-null real labels that currently have a matrix row
     */
    public SortedSet<L> getLabels() {
        return new TreeSet<L>(this.confusionMatrix.keySet());
    }

    /**
     * Maps every label returned by {@link #getLabels()} to its display name, in the same order.
     *
     * @param friendlyLabelMap label -> display name; labels missing from the map yield null entries
     * @return the display names, or null if {@code friendlyLabelMap} is null
     */
    public String[] getFriendlyLabels(Map<L, String> friendlyLabelMap) {
        if (friendlyLabelMap == null) {
            return null;
        }
        SortedSet<L> labels = this.getLabels();
        ArrayList<String> result = new ArrayList<String>(labels.size());
        for (L label : labels) {
            result.add(friendlyLabelMap.get(label));
        }
        return result.toArray(DSArrayUtils.EMPTY_STRING_ARRAY);
    }

    /**
     * Asserts (when assertions are enabled) that the number of stored predictions equals the
     * number of samples added. The null-real-label row is included, because addSample counts
     * those samples in the example total as well.
     */
    public void sanityCheck() {
        int predictionCount = this.confusionRowNull.size();
        for (Multiset<L> row : this.confusionMatrix.values()) {
            predictionCount += row.size();
        }
        assert predictionCount == this.numExamples
                : "stored predictions (" + predictionCount + ") != examples (" + this.numExamples + ")";
    }

    /**
     * Records one sample.
     *
     * @param realValue      the true label; null samples go to the separate null row
     * @param predictedValue the predicted label; null is counted as "unknown" by {@link #unknown()}
     */
    public void addSample(L realValue, L predictedValue) {
        // This is the only place where rows are meant to be created, hence the direct get().
        Multiset<L> confusionRow = realValue == null ? this.confusionRowNull : this.confusionMatrix.get(realValue);
        confusionRow.add(predictedValue);
        ++this.numExamples;
    }

    /**
     * @return the fraction of all examples whose prediction equals the (non-null) real label;
     *         NaN if no examples have been added
     */
    @Override
    public float accuracy() {
        int correct = 0;
        for (Map.Entry<L, Multiset<L>> entry : this.confusionMatrix.entrySet()) {
            correct += entry.getValue().count(entry.getKey());
        }
        return (float)correct / (float)this.numExamples;
    }

    /**
     * @return the fraction of all examples that have a non-null real label and a null prediction;
     *         NaN if no examples have been added
     */
    @Override
    public float unknown() {
        int unknown = 0;
        for (Multiset<L> row : this.confusionMatrix.values()) {
            unknown += row.count(null);
        }
        return (float)unknown / (float)this.numExamples;
    }

    /**
     * @return correct predictions divided by (examples - null predictions in the matrix rows);
     *         NaN or infinite if that denominator is zero
     */
    @Override
    public float accuracyGivenClassified() {
        int correct = 0;
        int unknown = 0;
        for (Map.Entry<L, Multiset<L>> entry : this.confusionMatrix.entrySet()) {
            Multiset<L> row = entry.getValue();
            correct += row.count(entry.getKey());
            unknown += row.count(null);
        }
        return (float)correct / ((float)this.numExamples - (float)unknown);
    }

    /**
     * @param label the real label of interest
     * @return true positives / samples with this real label; NaN if there are no such samples
     */
    public float sensitivity(L label) {
        Multiset<L> predictionsForLabel = this.rowFor(label);
        int totalWithRealLabel = predictionsForLabel.size();
        int truePositives = predictionsForLabel.count(label);
        return (float)truePositives / (float)totalWithRealLabel;
    }

    /**
     * @param label the predicted label of interest
     * @return true positives / {@link #getTotalPredicted(Comparable) total predicted} as this
     *         label, or 1 if the label was never predicted
     */
    public float precision(L label) {
        int truePositives = this.rowFor(label).count(label);
        float total = this.getTotalPredicted(label);
        return total == 0.0f ? 1.0f : (float)truePositives / total;
    }

    /**
     * @return {@link #specificity} for every label, ordered as in {@link #getLabels()}
     */
    public float[] getSpecificities() {
        SortedSet<L> labels = this.getLabels();
        float[] result = new float[labels.size()];
        int i = 0;
        for (L label : labels) {
            result[i] = this.specificity(label);
            ++i;
        }
        return result;
    }

    /**
     * @return {@link #sensitivity} for every label, ordered as in {@link #getLabels()}
     */
    public float[] getSensitivities() {
        SortedSet<L> labels = this.getLabels();
        float[] result = new float[labels.size()];
        int i = 0;
        for (L label : labels) {
            result[i] = this.sensitivity(label);
            ++i;
        }
        return result;
    }

    /**
     * @return {@link #precision} for every label, ordered as in {@link #getLabels()}
     */
    public float[] getPrecisions() {
        SortedSet<L> labels = this.getLabels();
        float[] result = new float[labels.size()];
        int i = 0;
        for (L label : labels) {
            result[i] = this.precision(label);
            ++i;
        }
        return result;
    }

    /**
     * @return {@link #getTotalPredicted} for every label, ordered as in {@link #getLabels()}
     */
    public float[] getPredictedCounts() {
        SortedSet<L> labels = this.getLabels();
        float[] result = new float[labels.size()];
        int i = 0;
        for (L label : labels) {
            result[i] = this.getTotalPredicted(label);
            ++i;
        }
        return result;
    }

    /**
     * @return {@link #getTotalActual} for every label, ordered as in {@link #getLabels()}
     */
    public float[] getActualCounts() {
        SortedSet<L> labels = this.getLabels();
        float[] result = new float[labels.size()];
        int i = 0;
        for (L label : labels) {
            result[i] = this.getTotalActual(label);
            ++i;
        }
        return result;
    }

    /**
     * @param actual    the real label (null selects the null-real-label row)
     * @param predicted the predicted label
     * @return how many samples with real label {@code actual} were predicted as {@code predicted}
     */
    public int getCount(L actual, L predicted) {
        return this.rowFor(actual).count(predicted);
    }

    /**
     * @param label the label of interest
     * @return true negatives / samples whose real label is not {@code label}, or 1 if there are
     *         no such samples
     */
    public float specificity(L label) {
        Multiset<L> predictionsForLabel = this.rowFor(label);
        int hasLabel = predictionsForLabel.size();
        int hasLabelRight = predictionsForLabel.count(label);
        // Predicted as label although the real label differs (false positives).
        int notLabelWrong = this.getTotalPredicted(label) - hasLabelRight;
        int notLabel = this.numExamples - hasLabel;
        int notLabelRight = notLabel - notLabelWrong;
        if (notLabel == 0) {
            return 1.0f;
        }
        // NOTE(review): notLabel includes samples with a null real label, but getTotalPredicted
        // does not look at the null row, so false positives among those samples are not
        // subtracted here - confirm whether that is intended.
        return (float)notLabelRight / (float)notLabel;
    }

    /**
     * @param label the predicted label of interest
     * @return how many samples with a non-null real label were predicted as {@code label};
     *         the null-real-label row is not included
     */
    public int getTotalPredicted(L label) {
        int totalWithPredictedLabel = 0;
        for (Multiset<L> row : this.confusionMatrix.values()) {
            totalWithPredictedLabel += row.count(label);
        }
        return totalWithPredictedLabel;
    }

    /**
     * @param label the real label (null selects the null-real-label row)
     * @return how many samples were recorded with this real label
     */
    public int getTotalActual(L label) {
        return this.rowFor(label).size();
    }

    /**
     * @return the unweighted mean of {@link #specificity} over all labels with a matrix row;
     *         NaN if there are none
     */
    public float classNormalizedSpecificity() {
        float sum = 0.0f;
        for (L label : this.confusionMatrix.keySet()) {
            sum += this.specificity(label);
        }
        return sum / (float)this.confusionMatrix.size();
    }

    /**
     * @return the unweighted mean of {@link #sensitivity} over all labels with a matrix row;
     *         NaN if there are none
     */
    public float classNormalizedSensitivity() {
        float sum = 0.0f;
        for (L label : this.confusionMatrix.keySet()) {
            sum += this.sensitivity(label);
        }
        return sum / (float)this.confusionMatrix.size();
    }

    /**
     * @return the sum of the non-NaN {@link #precision} values divided by the number of labels
     *         with a matrix row (labels skipped as NaN still count in the denominator);
     *         NaN if there are no labels
     */
    public float classNormalizedPrecision() {
        float sum = 0.0f;
        for (L label : this.confusionMatrix.keySet()) {
            float v = this.precision(label);
            if (Float.isNaN(v)) {
                logger.warn("Label " + label + " did not contribute to precision; " + this.getTotalPredicted(label) + " predictions");
                continue;
            }
            sum += v;
        }
        return sum / (float)this.confusionMatrix.size();
    }

    /**
     * @return the number of distinct non-null real labels that have a matrix row
     */
    public int numPopulatedRealLabels() {
        return this.confusionMatrix.size();
    }

    /**
     * @return the number of distinct predicted values found across the matrix rows; the
     *         null-real-label row is not inspected
     */
    public int numPredictedLabels() {
        HashSet<L> predictedLabels = new HashSet<L>();
        for (Multiset<L> row : this.confusionMatrix.values()) {
            predictedLabels.addAll(row.elementSet());
        }
        return predictedLabels.size();
    }
}

