/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.ml.cluster.bayesian;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.dsutils.collections.ConcurrentHashWeightedSet;
import com.davidsoergel.dsutils.collections.MutableWeightedSet;
import com.davidsoergel.dsutils.collections.WeightedSet;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.ProbabilisticDissimilarityMeasure;
import com.google.common.collect.TreeMultimap;
import edu.berkeley.compbio.ml.cluster.AbstractSupervisedOnlineClusteringMethod;
import edu.berkeley.compbio.ml.cluster.AdditiveClusterable;
import edu.berkeley.compbio.ml.cluster.BasicCentroidCluster;
import edu.berkeley.compbio.ml.cluster.CentroidCluster;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.ClusterRuntimeException;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.PointClusterFilter;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.Nullable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Supervised classifier that labels a sample by polling several nearby training clusters
 * instead of committing to a single best cluster.
 *
 * <p>Training creates one {@link BasicCentroidCluster} per training point. Classification
 * works in three steps:
 * <ol>
 *   <li>Score every non-prohibited cluster by its distance to the sample
 *       ({@link #scoredClusterMoves}).</li>
 *   <li>Let up to {@link #maxNeighbors} of the closest clusters vote with their derived label
 *       probabilities ({@link #addUpNeighborVotes}).</li>
 *   <li>Pick the winning label ({@link VotingResults#getSubResults}).</li>
 * </ol>
 *
 * <p>This class is declared abstract; concrete subclasses complete it.
 *
 * @param <T> the type of the samples being classified
 */
public abstract class MultiNeighborClustering<T extends AdditiveClusterable<T>>
        extends AbstractSupervisedOnlineClusteringMethod<T, CentroidCluster<T>> {
    private static final Logger logger = Logger.getLogger(MultiNeighborClustering.class);

    /** Maximum number of nearest clusters whose labels are counted for a single sample. */
    protected final int maxNeighbors;

    /**
     * Only clusters strictly closer than this distance are considered when scoring a sample.
     * NaN distances are rejected as well.
     */
    protected final double unknownDistanceThreshold;

    /**
     * @param dm                       the dissimilarity measure used to score sample/cluster distances
     * @param unknownDistanceThreshold clusters at or beyond this distance are ignored
     * @param potentialTrainingBins    passed through to the superclass
     * @param predictLabelSets         passed through to the superclass
     * @param prohibitionModel         optional; when non-null, supplies a per-sample cluster filter
     * @param testLabels               passed through to the superclass
     * @param maxNeighbors             maximum number of neighbors that vote for a sample's label
     */
    public MultiNeighborClustering(DissimilarityMeasure<T> dm, double unknownDistanceThreshold, Set<String> potentialTrainingBins, Map<String, Set<String>> predictLabelSets, ProhibitionModel<T> prohibitionModel, Set<String> testLabels, int maxNeighbors) {
        super(dm, potentialTrainingBins, predictLabelSets, prohibitionModel, testLabels);
        this.maxNeighbors = maxNeighbors;
        this.unknownDistanceThreshold = unknownDistanceThreshold;
    }

    /**
     * Returns the label best supported by the sample's nearest neighbors, restricted to
     * {@code predictLabels}.
     *
     * @throws NoGoodClusterException if no cluster lies within the unknown-distance threshold,
     *                                or no counted neighbor votes for any of {@code predictLabels}
     */
    @Override
    public String bestLabel(final T sample, final Set<String> predictLabels) throws NoGoodClusterException {
        final TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> moves = this.scoredClusterMoves(sample);
        final VotingResults votingResults = this.addUpNeighborVotes(moves);
        return votingResults.getSubResults(predictLabels).getBestLabel();
    }

    /**
     * Adds one {@link BasicCentroidCluster} per training point, using {@link Parallel#forEach}.
     * Cluster ids come from an {@link AtomicInteger}, so they start at 1 and are unique even
     * when points are processed concurrently.
     */
    @Override
    public void trainWithKnownTrainingLabels(final ClusterableIterator<T> trainingIterator) {
        final AtomicInteger counter = new AtomicInteger(0);
        Parallel.forEach(trainingIterator, new Function<T, Void>() {

            @Override
            public Void apply(@Nullable final T point) {
                final int clusterId = counter.incrementAndGet();
                MultiNeighborClustering.this.addCluster(new BasicCentroidCluster<T>(clusterId, point));
                if (clusterId % 1000 == 0) {
                    logger.info("Trained " + clusterId + " samples");
                }
                return null;
            }
        });
        logger.info("Done training " + this.getNumClusters() + " samples");
    }

    /**
     * Tallies label votes from the closest clusters, at most {@link #maxNeighbors} of them.
     *
     * @param moves candidate moves keyed by distance; iterated in ascending-distance order
     * @return the accumulated votes and per-label contributions
     */
    protected VotingResults addUpNeighborVotes(final TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> moves) {
        final VotingResults result = new VotingResults();
        int neighborsCounted = 0;
        // Start at -infinity so the assertion below only checks ordering, not the sign of distances.
        double lastDistance = Double.NEGATIVE_INFINITY;
        for (final ClusterMove<T, CentroidCluster<T>> cm : moves.values()) {
            if (neighborsCounted >= this.maxNeighbors) {
                break;
            }
            // TreeMultimap yields values in ascending key (distance) order.
            assert cm.bestDistance >= lastDistance;
            lastDistance = cm.bestDistance;

            final WeightedSet<String> labelsOnThisCluster = cm.bestCluster.getDerivedLabelProbabilities();
            result.addVotes(labelsOnThisCluster, cm.voteWeight);
            for (final Map.Entry<String, Double> entry : labelsOnThisCluster.getItemNormalizedMap().entrySet()) {
                result.addContribution(cm, entry.getKey(), entry.getValue());
            }
            neighborsCounted++;
        }
        return result;
    }

    /**
     * Unsupported for multi-neighbor classification; always throws.
     *
     * @throws ClusterRuntimeException always
     */
    @Override
    public ClusterMove<T, CentroidCluster<T>> bestClusterMove(T p) throws NoGoodClusterException {
        throw new ClusterRuntimeException("It doesn't make sense to get the best clustermove with a multi-neighbor clustering; look for the best label instead using scoredClusterMoves");
    }

    /**
     * Scores every non-prohibited cluster against {@code p}.
     *
     * <p>When the measure is a {@link ProbabilisticDissimilarityMeasure}, the cluster's prior
     * from {@code clusterPriors} is passed to it as well.
     *
     * @return moves keyed by distance, containing only those strictly below
     *         {@link #unknownDistanceThreshold}
     * @throws NoGoodClusterException if no cluster passes the threshold
     */
    protected TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> scoredClusterMoves(final T p) throws NoGoodClusterException {
        final TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> result = TreeMultimap.create();
        final PointClusterFilter<T> clusterFilter = this.prohibitionModel == null ? null : this.prohibitionModel.getFilter(p);
        final boolean probabilistic = this.measure instanceof ProbabilisticDissimilarityMeasure;

        for (final CentroidCluster<T> cluster : this.getClusters()) {
            if (clusterFilter != null && clusterFilter.isProhibited(cluster)) {
                continue;
            }
            final double distance = probabilistic
                    ? ((ProbabilisticDissimilarityMeasure<T>) this.measure).distanceFromTo(p, cluster.getCentroid(), (Double) this.clusterPriors.get(cluster))
                    : this.measure.distanceFromTo(p, cluster.getCentroid());
            final ClusterMove<T, CentroidCluster<T>> cm = this.makeClusterMove(cluster, distance);
            // A NaN distance fails this comparison, so NaN-scored clusters are dropped too.
            if (cm.bestDistance < this.unknownDistanceThreshold) {
                result.put(cm.bestDistance, cm);
            }
        }
        if (result.isEmpty()) {
            throw new NoGoodClusterException("No clusters passed the unknown threshold");
        }
        return result;
    }

    /**
     * Creates a move pointing at {@code cluster} with the given distance.
     * Subclasses may override this to populate additional fields.
     */
    protected ClusterMove<T, CentroidCluster<T>> makeClusterMove(final CentroidCluster<T> cluster, final double distance) {
        final ClusterMove<T, CentroidCluster<T>> cm = new ClusterMove<T, CentroidCluster<T>>();
        cm.bestCluster = cluster;
        cm.bestDistance = distance;
        return cm;
    }

    /** The winning label and, when one exists, the runner-up. */
    public static class BestLabelPair {
        final String bestLabel;
        final String secondBestLabel;

        private BestLabelPair(String bestLabel, String secondBestLabel) {
            this.bestLabel = bestLabel;
            this.secondBestLabel = secondBestLabel;
        }

        public String getBestLabel() {
            return this.bestLabel;
        }

        /** @return the runner-up label, or {@code null} if there was none */
        public String getSecondBestLabel() {
            return this.secondBestLabel;
        }

        public boolean hasSecondBestLabel() {
            return this.secondBestLabel != null;
        }
    }

    /**
     * Accumulates weighted label votes from neighboring clusters.
     *
     * <p>For each label it also records which cluster moves contributed, and with what
     * normalized probability. {@code labelContributions} is a plain {@link HashMap}, so
     * {@link #addContribution} must not be called concurrently.
     */
    protected class VotingResults {
        private final Map<String, MutableWeightedSet<ClusterMove<T, CentroidCluster<T>>>> labelContributions =
                new HashMap<String, MutableWeightedSet<ClusterMove<T, CentroidCluster<T>>>>();
        private final MutableWeightedSet<String> labelVotes = new ConcurrentHashWeightedSet<String>();

        protected VotingResults() {
        }

        /** Records that {@code cm} supports {@code label} with the given probability. */
        public void addContribution(final ClusterMove<T, CentroidCluster<T>> cm, final String label, final Double labelProbability) {
            MutableWeightedSet<ClusterMove<T, CentroidCluster<T>>> contributionSet = this.labelContributions.get(label);
            if (contributionSet == null) {
                contributionSet = new ConcurrentHashWeightedSet<ClusterMove<T, CentroidCluster<T>>>();
                this.labelContributions.put(label, contributionSet);
            }
            contributionSet.add(cm, labelProbability, 1);
        }

        public void addVotes(WeightedSet<String> labelsOnThisCluster) {
            this.labelVotes.addAll(labelsOnThisCluster);
        }

        public void addVotes(WeightedSet<String> labelsOnThisCluster, double multiplier) {
            this.labelVotes.addAll(labelsOnThisCluster, multiplier);
        }

        /**
         * Returns the distance of the moves that contributed to {@code label}, each weighted by
         * its normalized contribution weight.
         *
         * @throws IllegalArgumentException if no contribution was recorded for {@code label}
         */
        public double computeWeightedDistance(final String label) {
            final WeightedSet<ClusterMove<T, CentroidCluster<T>>> contributions = this.labelContributions.get(label);
            if (contributions == null) {
                throw new IllegalArgumentException("No neighbor contributed to label " + label);
            }
            return this.computeWeightedDistance(contributions);
        }

        /**
         * Picks the best and second-best labels among {@code populatedTrainingLabels}.
         *
         * <p>Labels are ranked by {@code keysInDecreasingWeightOrder}. The supplied comparator
         * orders labels by ascending weighted distance.
         * NOTE(review): presumably the comparator is used as a tie-breaker between equal vote
         * weights; confirm against WeightedSet.
         *
         * @throws NoGoodClusterException if no vote was cast for any of the given labels
         */
        public BestLabelPair getSubResults(final Set<String> populatedTrainingLabels) throws NoGoodClusterException {
            final Comparator<String> weightedDistanceSort = new Comparator<String>() {
                // Weighted distances are memoized because a sort may compare the same label many times.
                private final Map<String, Double> cache = new HashMap<String, Double>();

                private double getWeightedDistance(final String label) {
                    Double result = this.cache.get(label);
                    if (result == null) {
                        result = VotingResults.this.computeWeightedDistance(label);
                        this.cache.put(label, result);
                    }
                    return result;
                }

                @Override
                public int compare(final String o1, final String o2) {
                    return Double.compare(this.getWeightedDistance(o1), this.getWeightedDistance(o2));
                }
            };

            final WeightedSet<String> subVotes = this.labelVotes.extractWithKeys(populatedTrainingLabels);
            final Iterator<String> vi = subVotes.keysInDecreasingWeightOrder(weightedDistanceSort).iterator();
            if (!vi.hasNext()) {
                throw new NoGoodClusterException("No neighbor voted for any of the requested labels");
            }
            final String bestLabel = vi.next();
            final String secondBestLabel = vi.hasNext() ? vi.next() : null;
            return new BestLabelPair(bestLabel, secondBestLabel);
        }

        /** Sums {@code normalizedWeight * bestDistance} over the contributing moves. */
        private double computeWeightedDistance(final WeightedSet<ClusterMove<T, CentroidCluster<T>>> dominantLabelContributions) {
            double weightedComputedDistance = 0.0;
            for (final Map.Entry<ClusterMove<T, CentroidCluster<T>>, Double> entry : dominantLabelContributions.getItemNormalizedMap().entrySet()) {
                weightedComputedDistance += entry.getValue() * entry.getKey().bestDistance;
            }
            return weightedComputedDistance;
        }

        public WeightedSet<String> getLabelVotes() {
            return this.labelVotes;
        }
    }
}

