/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.jlibsvm.multi;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.dsutils.collections.UnorderedPair;
import com.davidsoergel.dsutils.collections.UnorderedPairIterator;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterGrid;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.SVM;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationSVM;
import edu.berkeley.compbio.jlibsvm.binary.BinaryModel;
import edu.berkeley.compbio.jlibsvm.binary.BooleanClassificationProblemImpl;
import edu.berkeley.compbio.jlibsvm.labelinverter.LabelInverter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassProblem;
import edu.berkeley.compbio.jlibsvm.multi.SvmMultiClassCrossValidationResults;
import edu.berkeley.compbio.jlibsvm.util.SubtractionMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Multi-class SVM that decomposes a {@link MultiClassProblem} into binary sub-problems and trains each of them with
 * the wrapped {@link BinaryClassificationSVM}.
 * <p>
 * Depending on {@code param.oneVsAllMode} / {@code param.allVsAllMode}, this trains one-vs-all machines (one per
 * label, using a probability-enabled copy of the parameters) and/or one-vs-one machines (one per unordered pair of
 * labels). Sub-problems and grid points are processed through {@link Parallel#forEach}.
 *
 * @param <L> label type
 * @param <P> example type
 */
public class MultiClassificationSVM<L extends Comparable<L>, P>
extends SVM<L, P, MultiClassProblem<L, P>> {
    private static final Logger logger = Logger.getLogger(MultiClassificationSVM.class);

    /** Binary SVM used to train every one-vs-all and one-vs-one sub-model. */
    private final BinaryClassificationSVM<L, P> binarySvm;

    /**
     * @param binarySvm the binary SVM used to train each sub-model
     */
    public MultiClassificationSVM(BinaryClassificationSVM<L, P> binarySvm) {
        this.binarySvm = binarySvm;
    }

    /**
     * @return {@code "multiclass "} followed by the wrapped binary SVM's type name
     */
    @Override
    public String getSvmType() {
        return "multiclass " + this.binarySvm.getSvmType();
    }

    /**
     * Cross-validates {@code param} on {@code problem} and packages the predictions as a results object.
     *
     * @param problem the multi-class problem to cross-validate
     * @param param   the parameters to evaluate; recorded on the returned results
     * @return cross-validation results built from the inherited {@code discreteCrossValidation} predictions
     */
    public SvmMultiClassCrossValidationResults<L, P> performCrossValidation(@NotNull MultiClassProblem<L, P> problem, @NotNull ImmutableSvmParameter<L, P> param) {
        Map<P, L> predictions = this.discreteCrossValidation(problem, param);
        SvmMultiClassCrossValidationResults<L, P> cv = new SvmMultiClassCrossValidationResults<L, P>(problem, predictions);
        cv.param = param;
        return cv;
    }

    /**
     * Validates {@code param} and trains a multi-class model.
     * <p>
     * A grid of parameters is searched here (one shared point for all binary machines) when
     * {@code gridsearchBinaryMachinesIndependently} is false. Otherwise the parameters, possibly still a grid, are
     * handed unchanged to {@link #trainScaled}.
     */
    @Override
    public MultiClassModel<L, P> train(@NotNull MultiClassProblem<L, P> problem, @NotNull ImmutableSvmParameter<L, P> param) {
        this.validateParam(param);
        if (param instanceof ImmutableSvmParameterGrid && !param.gridsearchBinaryMachinesIndependently) {
            return this.trainGrid(problem, (ImmutableSvmParameterGrid<L, P>) param);
        }
        return this.trainScaled(problem, param);
    }

    /**
     * Cross-validates every grid point in parallel, then trains a final model on the point with the highest
     * class-normalized sensitivity and attaches that point's cross-validation results to the model.
     *
     * @throws IllegalStateException if no grid point produced a result that could be selected (e.g. the grid is
     *                               empty, or every sensitivity was NaN)
     */
    public MultiClassModel<L, P> trainGrid(final @NotNull MultiClassProblem<L, P> problem, @NotNull ImmutableSvmParameterGrid<L, P> param) {
        final GridTrainingResult gtresult = new GridTrainingResult();
        Collection<ImmutableSvmParameterPoint<L, P>> parameterPoints = param.getGridParams();
        Parallel.forEach(parameterPoints, new Function<ImmutableSvmParameterPoint<L, P>, Void>() {

            @Override
            public Void apply(ImmutableSvmParameterPoint<L, P> gridParam) {
                SvmMultiClassCrossValidationResults<L, P> crossValidationResults = MultiClassificationSVM.this.performCrossValidation(problem, gridParam);
                gtresult.update(crossValidationResults);
                return null;
            }
        });
        SvmMultiClassCrossValidationResults<L, P> best = gtresult.bestCrossValidationResults;
        if (best == null) {
            // Previously this surfaced as an opaque NullPointerException on best.param.
            throw new IllegalStateException("Grid search selected no parameter point (grid size " + parameterPoints.size() + ")");
        }
        logger.info("Chose grid point: " + best.param);
        MultiClassModel<L, P> result = this.trainScaled(problem, best.param);
        result.crossValidationResults = best;
        return result;
    }

    /**
     * Trains on a scaled copy of {@code problem} when a scaling learner is configured and scaling is not deferred to
     * the individual binary machines; otherwise trains on {@code problem} as given.
     */
    public MultiClassModel<L, P> trainScaled(@NotNull MultiClassProblem<L, P> problem, @NotNull ImmutableSvmParameter<L, P> param) {
        if (param.scalingModelLearner != null && !param.scaleBinaryMachinesIndependently) {
            return this.trainWithoutScaling(problem.getScaledCopy(param.scalingModelLearner), param);
        }
        return this.trainWithoutScaling(problem, param);
    }

    /**
     * Builds the multi-class model from its binary sub-models: one-vs-all machines and/or one-vs-one machines
     * depending on the modes in {@code param}, then calls {@code prepareModelSvMaps()} on the finished model.
     */
    private MultiClassModel<L, P> trainWithoutScaling(final @NotNull MultiClassProblem<L, P> problem, final @NotNull ImmutableSvmParameter<L, P> param) {
        final int numLabels = problem.getLabels().size();
        final MultiClassModel<L, P> model = new MultiClassModel<L, P>(param, numLabels);
        model.setScalingModel(problem.getScalingModel());
        final Map<L, Set<P>> examplesByLabel = problem.getExamplesByLabel();
        if (param.oneVsAllMode != MultiClassModel.OneVsAllMode.None) {
            this.trainOneVsAll(problem, param, examplesByLabel, model, numLabels);
        }
        if (param.allVsAllMode != MultiClassModel.AllVsAllMode.None) {
            this.trainOneVsOne(problem, param, examplesByLabel, model, numLabels);
        }
        model.prepareModelSvMaps();
        return model;
    }

    /**
     * Trains, in parallel, one binary machine per label (label vs. inverted label) using a probability-enabled copy
     * of {@code param}, and stores each via {@code model.putOneVsAllModel}.
     */
    private void trainOneVsAll(final MultiClassProblem<L, P> problem, final ImmutableSvmParameter<L, P> param,
                               final Map<L, Set<P>> examplesByLabel, final MultiClassModel<L, P> model, int numLabels) {
        final ImmutableSvmParameter<L, P> probParam = param.withProbabilityCopy();
        logger.info("Training one-vs-all classifiers for " + numLabels + " labels");
        final LabelInverter<L> labelInverter = problem.getLabelInverter();
        Parallel.forEach(problem.getLabels(), new Function<L, Void>() {

            @Override
            @SuppressWarnings("unchecked") // SubtractionMap / BooleanClassificationProblemImpl are constructed raw
            public Void apply(L label) {
                L notLabel = labelInverter.invert(label);
                Set<P> labelExamples = examplesByLabel.get(label);
                Collection<Map.Entry<P, L>> entries = problem.getExamples().entrySet();
                if (param.falseClassSVlimit != Integer.MAX_VALUE) {
                    // Shuffle before truncating so the capped candidate set is a random sample of all examples.
                    List<Map.Entry<P, L>> entryList = new ArrayList<Map.Entry<P, L>>(entries);
                    Collections.shuffle(entryList);
                    // Sum in long: falseClassSVlimit + size can overflow int for limits close to Integer.MAX_VALUE.
                    int toIndex = (int) Math.min((long) param.falseClassSVlimit + labelExamples.size(), (long) entryList.size());
                    entries = entryList.subList(0, toIndex);
                }
                Set<P> notLabelExamples = new SubtractionMap(entries, labelExamples, param.falseClassSVlimit).keySet();
                BooleanClassificationProblemImpl subProblem = new BooleanClassificationProblemImpl(problem.getLabelClass(), label, labelExamples, notLabel, notLabelExamples, problem.getExampleIds());
                BinaryModel result = MultiClassificationSVM.this.binarySvm.train(subProblem, probParam);
                model.putOneVsAllModel(label, result);
                return null;
            }
        });
    }

    /**
     * Trains, in parallel, one binary machine per unordered pair of distinct labels using {@code param} as given,
     * and stores each via {@code model.putOneVsOneModel}.
     */
    private void trainOneVsOne(final MultiClassProblem<L, P> problem, final ImmutableSvmParameter<L, P> param,
                               final Map<L, Set<P>> examplesByLabel, final MultiClassModel<L, P> model, int numLabels) {
        // Computed in long so the logged count stays correct even for very large label counts.
        long numClassifiers = (long) numLabels * (numLabels - 1) / 2;
        logger.info("Training " + numClassifiers + " one-vs-one classifiers for " + numLabels + " labels");
        UnorderedPairIterator labelPairIterator = new UnorderedPairIterator(problem.getLabels(), problem.getLabels());
        Parallel.forEach(labelPairIterator, new Function<UnorderedPair<L>, Void>() {

            @Override
            @SuppressWarnings("unchecked") // BooleanClassificationProblemImpl is constructed raw
            public Void apply(UnorderedPair<L> from) {
                L label1 = from.getKey1();
                L label2 = from.getKey2();
                Set<P> label1Examples = examplesByLabel.get(label1);
                Set<P> label2Examples = examplesByLabel.get(label2);
                BooleanClassificationProblemImpl subProblem = new BooleanClassificationProblemImpl(problem.getLabelClass(), label1, label1Examples, label2, label2Examples, problem.getExampleIds());
                BinaryModel result = MultiClassificationSVM.this.binarySvm.train(subProblem, param);
                model.putOneVsOneModel(label1, label2, result);
                return null;
            }
        });
    }

    /**
     * Accumulates the best grid-search result seen so far, ranked by class-normalized sensitivity.
     * {@link #update} is synchronized because it is called from the parallel tasks in {@code trainGrid}.
     */
    private class GridTrainingResult {
        SvmMultiClassCrossValidationResults<L, P> bestCrossValidationResults = null;
        float bestSensitivity = -1.0f;

        private GridTrainingResult() {
        }

        synchronized void update(SvmMultiClassCrossValidationResults<L, P> crossValidationResults) {
            float sensitivity = crossValidationResults.classNormalizedSensitivity();
            // Strict '>': on ties the result that arrived first is kept, and a NaN sensitivity never wins.
            if (sensitivity > this.bestSensitivity) {
                this.bestSensitivity = sensitivity;
                this.bestCrossValidationResults = crossValidationResults;
            }
        }
    }
}

