/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.jlibsvm.multi;

import com.google.common.base.Function;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.MapMaker;
import edu.berkeley.compbio.jlibsvm.DiscreteModel;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.SolutionModel;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.binary.BinaryModel;
import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import edu.berkeley.compbio.jlibsvm.multi.SvmMultiClassCrossValidationResults;
import edu.berkeley.compbio.jlibsvm.multi.VotingResult;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.jlibsvm.scaler.ScalingModel;
import edu.berkeley.compbio.ml.MultiClassCrossValidationResults;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class MultiClassModel<L extends Comparable, P>
extends SolutionModel<L, P>
implements DiscreteModel<L, P> {
    private static final Logger logger = Logger.getLogger(MultiClassModel.class);
    private ScalingModel<P> scalingModel = new NoopScalingModel();
    private final OneVsAllMode oneVsAllMode;
    private final double oneVsAllThreshold;
    private final AllVsAllMode allVsAllMode;
    private final double minVoteProportion;
    private final Map<BinaryModel<L, P>, int[]> svIndexMaps;
    private final int numberOfClasses;
    private final SymmetricHashMap2d<L, BinaryModel<L, P>> oneVsOneModels;
    private final HashMap<L, BinaryModel<L, P>> oneVsAllModels;
    private P[] allSVs;
    SvmMultiClassCrossValidationResults<L, P> crossValidationResults;

    @Override
    public MultiClassCrossValidationResults<L> getCrossValidationResults() {
        return this.crossValidationResults;
    }

    public MultiClassModel(MultiClassModel<L, P> copyFrom, Collection<L> excludeLabels) {
        this.allSVs = copyFrom.allSVs;
        this.oneVsAllMode = copyFrom.oneVsAllMode;
        this.oneVsAllThreshold = copyFrom.oneVsAllThreshold;
        this.allVsAllMode = copyFrom.allVsAllMode;
        this.minVoteProportion = copyFrom.minVoteProportion;
        this.numberOfClasses = copyFrom.numberOfClasses;
        this.svIndexMaps = copyFrom.svIndexMaps;
        this.scalingModel = copyFrom.scalingModel;
        this.oneVsOneModels = new SymmetricHashMap2d<L, BinaryModel<L, P>>(copyFrom.oneVsOneModels, excludeLabels);
        this.oneVsAllModels = new HashMap<L, BinaryModel<L, P>>(copyFrom.oneVsAllModels);
        for (Comparable disallowedLabel : excludeLabels) {
            this.oneVsAllModels.remove(disallowedLabel);
        }
    }

    public MultiClassModel(ImmutableSvmParameter param, int numberOfClasses) {
        this.svIndexMaps = new HashMap<BinaryModel<L, P>, int[]>();
        this.numberOfClasses = numberOfClasses;
        this.oneVsOneModels = new SymmetricHashMap2d(numberOfClasses);
        this.oneVsAllModels = new HashMap(numberOfClasses);
        this.oneVsAllThreshold = param.oneVsAllThreshold;
        this.oneVsAllMode = param.oneVsAllMode;
        this.allVsAllMode = param.allVsAllMode;
        this.minVoteProportion = param.minVoteProportion;
    }

    @NotNull
    public ScalingModel<P> getScalingModel() {
        return this.scalingModel;
    }

    public void setScalingModel(@NotNull ScalingModel<P> scalingModel) {
        this.scalingModel = scalingModel;
    }

    @Override
    public L predictLabel(P x) {
        return (L)((Comparable)this.predictLabelWithQuality(x).getBestLabel());
    }

    public L bestProbabilityLabel(Map<L, Float> labelProbabilities) {
        Float bestProb = Float.valueOf(0.0f);
        Comparable bestLabel = null;
        for (Map.Entry<L, Float> entry : labelProbabilities.entrySet()) {
            if (!(entry.getValue().floatValue() > bestProb.floatValue())) continue;
            bestLabel = (Comparable)entry.getKey();
            bestProb = entry.getValue();
        }
        return (L)bestLabel;
    }

    @NotNull
    public VotingResult<L> predictLabelWithQuality(P x) {
        Map<L, Float> oneVsAllProbabilities;
        final P scaledX = this.scalingModel.scaledCopy(x);
        Comparable bestLabel = null;
        float bestOneClassProbability = 0.0f;
        float secondBestOneClassProbability = 0.0f;
        float bestOneVsAllProbability = 0.0f;
        float secondBestOneVsAllProbability = 0.0f;
        ConcurrentMap<KernelFunction<P>, float[]> kValuesPerKernel = new MapMaker().makeComputingMap(new Function<KernelFunction<P>, float[]>(){

            @Override
            public float[] apply(@NotNull KernelFunction<P> kernel) {
                float[] kvalues = new float[MultiClassModel.this.allSVs.length];
                int i = 0;
                for (Object sv : MultiClassModel.this.allSVs) {
                    kvalues[i] = (float)kernel.evaluate(scaledX, sv);
                    ++i;
                }
                return kvalues;
            }
        });
        Map<L, Float> map = oneVsAllProbabilities = this.oneVsAllMode == OneVsAllMode.None ? null : this.computeOneVsAllProbabilities(kValuesPerKernel);
        if ((this.oneVsAllMode == OneVsAllMode.Veto || this.oneVsAllMode == OneVsAllMode.VetoAndBreakTies || this.oneVsAllMode == OneVsAllMode.Best) && oneVsAllProbabilities.isEmpty()) {
            return new VotingResult();
        }
        if (this.oneVsAllMode == OneVsAllMode.Best) {
            for (Map.Entry<L, Float> entry : oneVsAllProbabilities.entrySet()) {
                if (!(entry.getValue().floatValue() > bestOneVsAllProbability)) continue;
                secondBestOneVsAllProbability = bestOneVsAllProbability;
                bestLabel = (Comparable)entry.getKey();
                bestOneVsAllProbability = entry.getValue().floatValue();
            }
            return new VotingResult<Object>(bestLabel, 0.0f, 0.0f, bestOneClassProbability, secondBestOneClassProbability, bestOneVsAllProbability, secondBestOneVsAllProbability);
        }
        int numLabels = this.oneVsOneModels.keySet().size();
        HashMultiset<Object> votes = HashMultiset.create();
        if (this.allVsAllMode == AllVsAllMode.AllVsAll) {
            logger.debug("Sample voting using all pairs of " + numLabels + " labels (" + ((double)(numLabels * (numLabels - 1)) / 2.0 - (double)numLabels) + " models)");
            for (BinaryModel<L, P> binaryModel : this.oneVsOneModels.values()) {
                float[] kvalues = (float[])kValuesPerKernel.get(binaryModel.param.kernel);
                votes.add(binaryModel.predictLabel(kvalues, this.svIndexMaps.get(binaryModel)));
            }
        } else {
            int numActive;
            Set<L> activeClasses = oneVsAllProbabilities != null ? oneVsAllProbabilities.keySet() : this.oneVsOneModels.keySet();
            int requiredActive = this.allVsAllMode == AllVsAllMode.FilteredVsAll ? 1 : 2;
            int n = numActive = oneVsAllProbabilities != null ? oneVsAllProbabilities.size() : numLabels;
            if (requiredActive == 1) {
                logger.debug("Sample voting with all " + numLabels + " vs. " + numActive + " active labels (" + ((double)(numLabels * (numActive - 1)) / 2.0 - (double)numActive) + " models)");
            } else {
                logger.debug("Sample voting using pairs of only " + numActive + " active labels (" + ((double)(numActive * (numActive - 1)) / 2.0 - (double)numActive) + " models)");
            }
            for (BinaryModel binaryModel : this.oneVsOneModels.values()) {
                int activeCount = (activeClasses.contains(binaryModel.getTrueLabel()) ? 1 : 0) + (activeClasses.contains(binaryModel.getFalseLabel()) ? 1 : 0);
                if (activeCount < requiredActive) continue;
                votes.add(binaryModel.predictLabel((Object)scaledX));
            }
        }
        int bestCount = 0;
        int secondBestCount = 0;
        int countSum = 0;
        for (Comparable comparable : votes.elementSet()) {
            int count = votes.count(comparable);
            countSum += count;
            Float oneVsAll = Float.valueOf(1.0f);
            if (this.oneVsAllMode == OneVsAllMode.Veto || this.oneVsAllMode == OneVsAllMode.VetoAndBreakTies) {
                oneVsAll = oneVsAllProbabilities.get(comparable);
                oneVsAll = Float.valueOf(oneVsAll == null ? 0.0f : oneVsAll.floatValue());
            }
            if (count <= bestCount && (count != bestCount || !(oneVsAll.floatValue() > bestOneVsAllProbability))) continue;
            secondBestCount = bestCount;
            secondBestOneVsAllProbability = bestOneVsAllProbability;
            bestLabel = comparable;
            bestCount = count;
            bestOneVsAllProbability = oneVsAll.floatValue();
        }
        double bestVoteProportion = (double)bestCount / (double)countSum;
        double secondBestVoteProportion = (double)secondBestCount / (double)countSum;
        if (bestVoteProportion < this.minVoteProportion) {
            return new VotingResult();
        }
        if ((this.oneVsAllMode == OneVsAllMode.VetoAndBreakTies || this.oneVsAllMode == OneVsAllMode.Veto) && (double)bestOneVsAllProbability < this.oneVsAllThreshold) {
            return new VotingResult();
        }
        return new VotingResult<Comparable>(bestLabel, (float)bestVoteProportion, (float)secondBestVoteProportion, bestOneClassProbability, secondBestOneClassProbability, bestOneVsAllProbability, secondBestOneVsAllProbability);
    }

    public Map<L, Float> computeOneVsAllProbabilities(Map<KernelFunction<P>, float[]> kValuesPerKernel) {
        HashMap<L, Float> oneVsAllProbabilities = new HashMap<L, Float>();
        for (BinaryModel<L, P> binaryModel : this.oneVsAllModels.values()) {
            float[] kvalues;
            float probability = binaryModel.getTrueProbability(kvalues = kValuesPerKernel.get(binaryModel.param.kernel), this.svIndexMaps.get(binaryModel));
            if (!((double)probability >= this.oneVsAllThreshold)) continue;
            oneVsAllProbabilities.put(binaryModel.getTrueLabel(), Float.valueOf(probability));
        }
        return oneVsAllProbabilities;
    }

    public Map<L, Float> predictProbability(P x) {
        if (!this.supportsOneVsOneProbability()) {
            throw new SvmException("Can't make probability predictions");
        }
        float minimumProbability = 1.0E-7f;
        float[][] pairwiseProbabilities = new float[this.numberOfClasses][this.numberOfClasses];
        ArrayList<L> labels = new ArrayList<L>(this.oneVsOneModels.keySet());
        for (int i = 0; i < this.numberOfClasses; ++i) {
            Comparable label1 = (Comparable)labels.get(i);
            for (int j = i + 1; j < this.numberOfClasses; ++j) {
                Comparable label2 = (Comparable)labels.get(j);
                BinaryModel<L, P> binaryModel = this.oneVsOneModels.get(label1, label2);
                if (binaryModel == null) {
                    pairwiseProbabilities[i][j] = 0.0f;
                    pairwiseProbabilities[j][i] = 0.0f;
                    continue;
                }
                float prob = binaryModel.crossValidationResults.getSigmoid().predict(binaryModel.predictValue(x).floatValue());
                pairwiseProbabilities[i][j] = Math.min(Math.max(prob, minimumProbability), 1.0f - minimumProbability);
                pairwiseProbabilities[j][i] = 1.0f - pairwiseProbabilities[i][j];
            }
        }
        float[] probabilityEstimates = this.multiclassProbability(this.numberOfClasses, pairwiseProbabilities);
        HashMap<Comparable, Float> result = new HashMap<Comparable, Float>();
        int i = 0;
        for (Comparable label : labels) {
            result.put(label, Float.valueOf(probabilityEstimates[i]));
            ++i;
        }
        return result;
    }

    public boolean supportsOneVsOneProbability() {
        return this.oneVsOneModels.valueIterator().next().crossValidationResults != null;
    }

    private float[] multiclassProbability(int k, float[][] r) {
        int j;
        int t;
        float[] p = new float[k];
        int iter = 0;
        int maximumIterations = Math.max(100, k);
        float[][] Q = new float[k][k];
        float[] Qp = new float[k];
        float eps = 0.005f / (float)k;
        for (t = 0; t < k; ++t) {
            p[t] = 1.0f / (float)k;
            Q[t][t] = 0.0f;
            for (j = 0; j < t; ++j) {
                float[] fArray = Q[t];
                int n = t;
                fArray[n] = fArray[n] + r[j][t] * r[j][t];
                Q[t][j] = Q[j][t];
            }
            for (j = t + 1; j < k; ++j) {
                float[] fArray = Q[t];
                int n = t;
                fArray[n] = fArray[n] + r[j][t] * r[j][t];
                Q[t][j] = -r[j][t] * r[t][j];
            }
        }
        for (iter = 0; iter < maximumIterations; ++iter) {
            float pQp = 0.0f;
            for (t = 0; t < k; ++t) {
                Qp[t] = 0.0f;
                for (j = 0; j < k; ++j) {
                    int n = t;
                    Qp[n] = Qp[n] + Q[t][j] * p[j];
                }
                pQp += p[t] * Qp[t];
            }
            float maxError = 0.0f;
            for (t = 0; t < k; ++t) {
                float error = Math.abs(Qp[t] - pQp);
                if (!(error > maxError)) continue;
                maxError = error;
            }
            if (maxError < eps) break;
            for (t = 0; t < k; ++t) {
                float diff = (-Qp[t] + pQp) / Q[t][t];
                int n = t;
                p[n] = p[n] + diff;
                pQp = (pQp + diff * (diff * Q[t][t] + 2.0f * Qp[t])) / (1.0f + diff) / (1.0f + diff);
                j = 0;
                while (j < k) {
                    Qp[j] = (Qp[j] + diff * Q[t][j]) / (1.0f + diff);
                    int n2 = j++;
                    p[n2] = p[n2] / (1.0f + diff);
                }
            }
        }
        if (iter >= maximumIterations) {
            logger.error("Multiclass probability attempted too many iterations");
        }
        return p;
    }

    public void prepareModelSvMaps() {
        Integer allSVsIndex;
        int i;
        int[] svIndexMap;
        int totalSVs = 0;
        HashMap<Object, Integer> allSVsMap = new HashMap<Object, Integer>();
        for (BinaryModel<L, P> binaryModel : this.oneVsAllModels.values()) {
            svIndexMap = new int[binaryModel.SVs.length];
            i = 0;
            for (Object p : binaryModel.SVs) {
                allSVsIndex = (Integer)allSVsMap.get(p);
                if (allSVsIndex == null) {
                    allSVsIndex = totalSVs;
                    allSVsMap.put(p, allSVsIndex);
                    ++totalSVs;
                }
                svIndexMap[i] = allSVsIndex;
                ++i;
            }
            this.svIndexMaps.put(binaryModel, svIndexMap);
        }
        for (BinaryModel<L, P> binaryModel : this.oneVsOneModels.values()) {
            svIndexMap = new int[binaryModel.SVs.length];
            i = 0;
            for (Object p : binaryModel.SVs) {
                allSVsIndex = (Integer)allSVsMap.get(p);
                if (allSVsIndex == null) {
                    allSVsIndex = totalSVs;
                    allSVsMap.put(p, allSVsIndex);
                    ++totalSVs;
                }
                svIndexMap[i] = allSVsIndex;
                ++i;
            }
            this.svIndexMaps.put(binaryModel, svIndexMap);
        }
        this.allSVs = new Object[totalSVs];
        for (Map.Entry entry : allSVsMap.entrySet()) {
            this.allSVs[((Integer)entry.getValue()).intValue()] = entry.getKey();
        }
    }

    public synchronized void putOneVsAllModel(L label1, BinaryModel<L, P> binaryModel) {
        this.oneVsAllModels.put(label1, binaryModel);
    }

    public synchronized void putOneVsOneModel(L label1, L label2, BinaryModel<L, P> binaryModel) {
        this.oneVsOneModels.put(label1, label2, binaryModel);
    }

    @Override
    protected void readSupportVectors(BufferedReader fp) {
        throw new UnsupportedOperationException();
    }

    protected void writeSupportVectors(DataOutputStream fp) throws IOException {
        fp.writeBytes("SV\n");
        fp.writeBytes("Saving multi-class support vectors is not implemented yet");
    }

    @Override
    public void writeToStream(DataOutputStream fp) throws IOException {
        throw new NotImplementedException();
    }

    public String getInfo() {
        HashMultiset kernels;
        HashMultiset<Float> cs;
        if (this.crossValidationResults != null) {
            return this.crossValidationResults.getInfo();
        }
        StringBuffer result = new StringBuffer();
        if (this.oneVsAllMode != OneVsAllMode.None) {
            cs = HashMultiset.create();
            kernels = HashMultiset.create();
            for (BinaryModel<L, P> binaryModel : this.oneVsAllModels.values()) {
                cs.add(Float.valueOf(binaryModel.param.C));
                kernels.add(binaryModel.param.kernel);
            }
            result.append("OneVsAll:C=" + cs + "; gamma=" + kernels + "   ");
        }
        if (this.allVsAllMode != AllVsAllMode.None) {
            cs = HashMultiset.create();
            kernels = HashMultiset.create();
            for (BinaryModel<L, P> binaryModel : this.oneVsOneModels.values()) {
                cs.add(Float.valueOf(binaryModel.param.C));
                kernels.add(binaryModel.param.kernel);
            }
            result.append("AllVsAll:C=" + cs + "; gamma=" + kernels + "   ");
        }
        return result.toString();
    }

    @Override
    public Collection<L> getLabels() {
        if (this.oneVsOneModels != null && !this.oneVsOneModels.isEmpty()) {
            return this.oneVsOneModels.values().iterator().next().getLabels();
        }
        if (this.oneVsAllModels != null && !this.oneVsAllModels.isEmpty()) {
            return this.oneVsAllModels.values().iterator().next().getLabels();
        }
        throw new SvmException("Can't get labels from a MultiClassModel with no subsidiary BinaryModels");
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class SymmetricHashMap2d<K extends Comparable, V> {
        HashMap<K, Map<K, V>> l1Map;
        private int sizePerDimension;

        public boolean isEmpty() {
            return this.l1Map.isEmpty();
        }

        public SymmetricHashMap2d(SymmetricHashMap2d<K, V> copyFrom, Collection<K> excludeKeys) {
            this(copyFrom.sizePerDimension);
            for (Map.Entry<K, Map<K, V>> entry : copyFrom.l1Map.entrySet()) {
                Comparable k1 = (Comparable)entry.getKey();
                if (excludeKeys.contains(k1)) continue;
                HashMap<Comparable, V> l2Map = new HashMap<Comparable, V>(this.sizePerDimension);
                for (Map.Entry<K, V> entry2 : entry.getValue().entrySet()) {
                    Comparable k2 = (Comparable)entry2.getKey();
                    if (excludeKeys.contains(k2)) continue;
                    l2Map.put(k2, entry2.getValue());
                }
                this.l1Map.put(k1, l2Map);
            }
        }

        public SymmetricHashMap2d(int sizePerDimension) {
            this.sizePerDimension = sizePerDimension;
            this.l1Map = new HashMap(sizePerDimension);
        }

        V get(K k1, K k2) {
            Map<K, V> l2Map;
            if (k1.compareTo(k2) > 0) {
                K k3 = k1;
                k1 = k2;
                k2 = k3;
            }
            if ((l2Map = this.l1Map.get(k1)) == null) {
                l2Map = new HashMap(this.sizePerDimension);
                this.l1Map.put(k1, l2Map);
            }
            return l2Map.get(k2);
        }

        public Set<K> keySet() {
            return this.l1Map.keySet();
        }

        public void put(K k1, K k2, V value) {
            Map<K, V> l2Map;
            if (k1.compareTo(k2) > 0) {
                K k3 = k1;
                k1 = k2;
                k2 = k3;
            }
            if ((l2Map = this.l1Map.get(k1)) == null) {
                l2Map = new HashMap();
                this.l1Map.put(k1, l2Map);
            }
            l2Map.put(k2, value);
        }

        public Iterable<V> values() {
            return new Iterable<V>(){

                @Override
                public Iterator<V> iterator() {
                    return SymmetricHashMap2d.this.valueIterator();
                }
            };
        }

        public Iterator<V> valueIterator() {
            return new Iterator<V>(){
                Iterator<K> k1iter;
                Iterator<V> l2iter;
                {
                    this.k1iter = SymmetricHashMap2d.this.l1Map.keySet().iterator();
                    this.l2iter = null;
                }

                @Override
                public boolean hasNext() {
                    return this.l2iter != null && this.l2iter.hasNext() || this.k1iter.hasNext();
                }

                @Override
                public V next() {
                    if (this.l2iter == null || !this.l2iter.hasNext()) {
                        if (this.k1iter.hasNext()) {
                            this.l2iter = SymmetricHashMap2d.this.l1Map.get(this.k1iter.next()).values().iterator();
                        } else {
                            return null;
                        }
                    }
                    return this.l2iter.next();
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    public static enum AllVsAllMode {
        None,
        AllVsAll,
        FilteredVsAll,
        FilteredVsFiltered;

    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    public static enum OneVsAllMode {
        None,
        Best,
        Veto,
        BreakTies,
        VetoAndBreakTies;

    }
}

