/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.jlibsvm.binary;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterGrid;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.SVM;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationProblem;
import edu.berkeley.compbio.jlibsvm.binary.BinaryModel;
import edu.berkeley.compbio.jlibsvm.binary.SvmBinaryCrossValidationResults;
import java.util.Map;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Base class for SVMs that solve a two-class problem.
 *
 * <p>{@link #train} dispatches on the kind of parameter object it receives: a parameter grid
 * triggers a cross-validated grid search, a single parameter point is trained directly, with an
 * additional cross-validation pass when {@code param.probability} is set. Subclasses supply the
 * actual solver through {@link #trainOne}.
 *
 * @param <L> label type
 * @param <P> sample (point) type
 */
public abstract class BinaryClassificationSVM<L extends Comparable, P>
        extends SVM<L, P, BinaryClassificationProblem<L, P>> {
    private static final Logger logger = Logger.getLogger(BinaryClassificationSVM.class);

    /**
     * Validates {@code param} and trains a binary model for {@code problem}.
     *
     * @param problem the two-class problem to train on
     * @param param   either an {@link ImmutableSvmParameterGrid} (grid search) or an
     *                {@link ImmutableSvmParameterPoint} (single training run)
     * @return the trained model
     * @throws IllegalStateException if a grid search does not select any grid point
     */
    @Override
    public BinaryModel<L, P> train(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameter<L, P> param) {
        this.validateParam(param);
        if (param instanceof ImmutableSvmParameterGrid) {
            return this.trainGrid(problem, (ImmutableSvmParameterGrid<L, P>) param);
        }
        ImmutableSvmParameterPoint<L, P> point = (ImmutableSvmParameterPoint<L, P>) param;
        return param.probability ? this.trainScaledWithCV(problem, point) : this.trainScaled(problem, point);
    }

    /**
     * Cross-validates every point of the grid, then trains once on the full problem with the
     * point that achieved the highest class-normalized sensitivity. The winning point's
     * cross-validation results are attached to the returned model.
     *
     * @throws IllegalStateException if no grid point was selected (the grid is empty, or no
     *                               point produced a sensitivity greater than the initial -1,
     *                               e.g. because every value was NaN)
     */
    private BinaryModel<L, P> trainGrid(final @NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterGrid<L, P> param) {
        final GridTrainingResult gtresult = new GridTrainingResult();
        // NOTE(review): this relies on Parallel.forEach returning only after every grid point
        // has been processed - confirm against the conja documentation.
        Parallel.forEach(param.getGridParams(), new Function<ImmutableSvmParameterPoint<L, P>, Void>() {

            @Override
            public Void apply(ImmutableSvmParameterPoint<L, P> gridParam) {
                // Cross-validation is required here: it is the only way to rank grid points.
                SvmBinaryCrossValidationResults<L, P> crossValidationResults = BinaryClassificationSVM.this.performCrossValidation(problem, gridParam);
                gtresult.update(gridParam, crossValidationResults);
                return null;
            }
        });

        // Read both fields under the same monitor that update() writes them under, so the
        // winning point and its cross-validation results are a consistent, visible pair.
        final ImmutableSvmParameterPoint<L, P> bestParam;
        final SvmBinaryCrossValidationResults<L, P> bestCrossValidationResults;
        synchronized (gtresult) {
            bestParam = gtresult.bestParam;
            bestCrossValidationResults = gtresult.bestCrossValidationResults;
        }
        if (bestParam == null) {
            // Fail with a clear message instead of a NullPointerException inside trainScaled.
            throw new IllegalStateException("Grid search did not select any parameter point: "
                    + "the grid is empty or no point produced a usable cross-validation sensitivity");
        }
        logger.info("Chose grid point: " + bestParam);

        // Final training run on all of the data with the winning parameters.
        BinaryModel<L, P> result = this.trainScaled(problem, bestParam);
        result.crossValidationResults = bestCrossValidationResults;
        return result;
    }

    /**
     * Trains a single parameter point and attaches cross-validation results to the model.
     *
     * <p>Cross-validation is best-effort: if it throws an {@link SvmException} the failure is
     * logged and the model is returned with {@code crossValidationResults} set to {@code null}.
     */
    private BinaryModel<L, P> trainScaledWithCV(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterPoint<L, P> param) {
        SvmBinaryCrossValidationResults<L, P> cv = null;
        try {
            cv = this.performCrossValidation(problem, param);
        }
        catch (SvmException e) {
            // Logged at warn rather than debug: the caller asked for cross-validation results
            // and will receive a model without them.
            logger.warn("Could not perform cross-validation", e);
        }
        BinaryModel<L, P> result = this.trainScaled(problem, param);
        result.crossValidationResults = cv;
        // trainScaled already printed the solution info once; it is printed again here now that
        // the cross-validation results are attached (kept as in the original behaviour).
        result.printSolutionInfo(problem);
        return result;
    }

    /**
     * Runs continuous cross-validation on {@code problem} using a no-probability copy of
     * {@code param} and wraps the resulting decision values.
     *
     * @param problem the problem to cross-validate
     * @param param   the parameters to evaluate; its {@code noProbabilityCopy()} must be an
     *                {@link ImmutableSvmParameterPoint}, otherwise a
     *                {@link ClassCastException} is thrown
     * @return the cross-validation results; {@code param.probability} is forwarded to the
     *         results object
     */
    public SvmBinaryCrossValidationResults<L, P> performCrossValidation(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameter<L, P> param) {
        ImmutableSvmParameterPoint<L, P> noProbParam = (ImmutableSvmParameterPoint<L, P>) param.noProbabilityCopy();
        // The decision-value map is passed straight through so its element types are inferred
        // from the declarations rather than restated (or left raw) here.
        return new SvmBinaryCrossValidationResults<L, P>(problem, this.continuousCrossValidation(problem, noProbParam), param.probability);
    }

    /**
     * Solves one binary problem with explicit per-class cost values.
     *
     * @param var1 the problem to solve
     * @param var2 the cost applied to the true-label class (passed as {@code weightedCp})
     * @param var3 the cost applied to the false-label class (passed as {@code weightedCn})
     * @param var4 the remaining training parameters
     * @return the trained model
     */
    protected abstract BinaryModel<L, P> trainOne(@NotNull BinaryClassificationProblem<L, P> var1, float var2, float var3, @NotNull ImmutableSvmParameterPoint<L, P> var4);

    /**
     * Optionally rescales the problem, trains it with class-weighted costs, and prints the
     * solution info.
     *
     * <p>Rescaling happens only when a {@code scalingModelLearner} is configured and
     * {@code scaleBinaryMachinesIndependently} is set; the scaled copy is then also the problem
     * passed to {@code printSolutionInfo}.
     */
    private BinaryModel<L, P> trainScaled(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterPoint<L, P> param) {
        if (param.scalingModelLearner != null && param.scaleBinaryMachinesIndependently) {
            problem = problem.getScaledCopy(param.scalingModelLearner);
        }
        BinaryModel<L, P> result = this.trainWeighted(problem, param);
        result.printSolutionInfo(problem);
        return result;
    }

    /**
     * Derives the per-class cost values from {@code param.C} and delegates to {@link #trainOne}.
     *
     * <p>When {@code param.redistributeUnbalancedC} is set, each class's cost is multiplied by
     * the weight configured for its label; a label with no configured weight keeps plain
     * {@code C}.
     */
    private BinaryModel<L, P> trainWeighted(@NotNull BinaryClassificationProblem<L, P> problem, @NotNull ImmutableSvmParameterPoint<L, P> param) {
        float weightedCp = param.C;
        float weightedCn = param.C;
        if (param.redistributeUnbalancedC) {
            Float weightP = param.getWeight(problem.getTrueLabel());
            if (weightP != null) {
                weightedCp *= weightP.floatValue();
            }
            Float weightN = param.getWeight(problem.getFalseLabel());
            if (weightN != null) {
                weightedCn *= weightN.floatValue();
            }
        }
        return this.trainOne(problem, weightedCp, weightedCn, param);
    }

    /**
     * Accumulator for the best grid point seen so far during a grid search.
     *
     * <p>{@link #update} is synchronized because it is invoked from the callbacks handed to
     * {@code Parallel.forEach}; readers of the fields should hold this object's monitor as well.
     */
    private class GridTrainingResult {
        ImmutableSvmParameterPoint<L, P> bestParam = null;
        SvmBinaryCrossValidationResults<L, P> bestCrossValidationResults = null;
        // Starts below any sensitivity that can win, so the first comparable result is accepted.
        float bestSensitivity = -1.0f;

        private GridTrainingResult() {
        }

        /**
         * Records {@code gridParam} as the new best point if its class-normalized sensitivity is
         * strictly greater than the best seen so far. Ties keep the earlier point, and a NaN
         * sensitivity never wins because the comparison is false.
         */
        synchronized void update(ImmutableSvmParameterPoint<L, P> gridParam, SvmBinaryCrossValidationResults<L, P> crossValidationResults) {
            float sensitivity = crossValidationResults.classNormalizedSensitivity();
            if (sensitivity > this.bestSensitivity) {
                this.bestParam = gridParam;
                this.bestSensitivity = sensitivity;
                this.bestCrossValidationResults = crossValidationResults;
            }
        }
    }
}

