/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.ml.cluster.kmeans;

import com.davidsoergel.stats.DissimilarityMeasure;
import edu.berkeley.compbio.ml.cluster.AbstractUnsupervisedOnlineClusteringMethod;
import edu.berkeley.compbio.ml.cluster.AdditiveCentroidCluster;
import edu.berkeley.compbio.ml.cluster.AdditiveClusterable;
import edu.berkeley.compbio.ml.cluster.CentroidCluster;
import edu.berkeley.compbio.ml.cluster.CentroidClusteringMethod;
import edu.berkeley.compbio.ml.cluster.CentroidClusteringUtils;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import edu.berkeley.compbio.ml.cluster.SampleInitializedOnlineClusteringMethod;
import edu.berkeley.compbio.ml.cluster.SemisupervisedClusteringMethod;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Online centroid-based clustering: each point offered via {@link #add} is placed in the
 * cluster whose centroid is closest to it according to the configured dissimilarity measure.
 *
 * <p>Clusters are seeded from the first samples of a training iterator (see
 * {@link #initializeWithSamples}); statistics reporting is delegated to
 * {@link CentroidClusteringUtils}.
 *
 * @param <T> the type of point being clustered
 */
public class KmeansClustering<T extends AdditiveClusterable<T>>
extends AbstractUnsupervisedOnlineClusteringMethod<T, CentroidCluster<T>>
implements SemisupervisedClusteringMethod<T>,
CentroidClusteringMethod<T>,
SampleInitializedOnlineClusteringMethod<T> {
    private static final Logger logger = Logger.getLogger(KmeansClustering.class);

    /**
     * Creates a clustering method; all arguments are passed unchanged to the superclass
     * constructor.
     *
     * @param dm                    the dissimilarity measure used to compare points with centroids
     * @param potentialTrainingBins forwarded to the superclass
     * @param predictLabelSets      forwarded to the superclass
     * @param prohibitionModel      forwarded to the superclass
     * @param testLabels            forwarded to the superclass
     */
    public KmeansClustering(DissimilarityMeasure<T> dm, Set<String> potentialTrainingBins, Map<String, Set<String>> predictLabelSets, ProhibitionModel<T> prohibitionModel, Set<String> testLabels) {
        super(dm, potentialTrainingBins, predictLabelSets, prohibitionModel, testLabels);
    }

    /**
     * Returns the short statistics summary produced by
     * {@link CentroidClusteringUtils#shortClusteringStats} for the current clusters.
     *
     * @return the summary string
     */
    @Override
    public String shortClusteringStats() {
        return CentroidClusteringUtils.shortClusteringStats(this.getClusters(), this.measure);
    }

    /**
     * Delegates to {@link CentroidClusteringUtils#computeClusterStdDevs} using the current
     * clusters, measure and assignments.
     *
     * @param theDataPointProvider iterator over the data points, handed to the utility method
     */
    @Override
    public void computeClusterStdDevs(ClusterableIterator<T> theDataPointProvider) {
        CentroidClusteringUtils.computeClusterStdDevs(this.getClusters(), this.measure, this.getAssignments(), theDataPointProvider);
    }

    /**
     * Renders the full clustering statistics into a string by writing them to an in-memory
     * buffer.
     *
     * @return the buffered statistics, decoded with the platform default charset
     */
    @Override
    public String clusteringStats() {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        CentroidClusteringUtils.writeClusteringStatsToStream(this.getClusters(), this.measure, buffer);
        // NOTE(review): decoding uses the platform default charset; presumably this matches
        // what the utility writes -- confirm before pinning an explicit charset.
        return buffer.toString();
    }

    /**
     * Writes the full clustering statistics to the given stream.
     *
     * @param outf destination stream; it is passed to the utility method and not closed here
     */
    @Override
    public void writeClusteringStatsToStream(OutputStream outf) {
        CentroidClusteringUtils.writeClusteringStatsToStream(this.getClusters(), this.measure, outf);
    }

    /**
     * Assigns the point to its best cluster, moving it out of its previous cluster if needed.
     *
     * @param p the point to place; must not be null
     * @return {@code true} if the point was moved into a cluster, {@code false} if the best
     *         move reports no change or no cluster could be chosen for the point
     */
    @Override
    public boolean add(T p) {
        assert (p != null);
        String id = p.getId();
        ClusterMove<T, CentroidCluster<T>> cm = this.bestClusterMove(p);
        if (!cm.isChanged()) {
            return false;
        }
        CentroidCluster<T> target = cm.bestCluster;
        if (target == null) {
            // bestClusterMove already logged that the point cannot be classified. Bail out
            // before touching the old cluster so the existing assignment stays consistent.
            return false;
        }
        CentroidCluster<T> previous = cm.oldCluster;
        // A point with no recorded assignment has no old cluster to leave.
        if (previous != null) {
            previous.remove(p);
        }
        target.add(p);
        this.putAssignment(id, target);
        return true;
    }

    /**
     * Creates {@code initsamples} clusters, each seeded with the next fully labelled sample
     * from the iterator and numbered from 0.
     *
     * @param trainingIterator source of seed samples
     * @param initsamples      number of clusters to create
     */
    @Override
    public void initializeWithSamples(ClusterableIterator<T> trainingIterator, int initsamples) {
        for (int i = 0; i < initsamples; ++i) {
            CentroidCluster<T> seeded = new AdditiveCentroidCluster<T>(i, trainingIterator.nextFullyLabelled());
            this.addCluster(seeded);
        }
        logger.debug("initialized " + initsamples + " clusters");
    }

    /**
     * Finds the cluster whose centroid is nearest to the point, also tracking the
     * second-smallest distance seen.
     *
     * @param p the point to classify
     * @return a move holding the point's previously assigned cluster, the best cluster
     *         (null if no cluster beat the move's initial best distance) and the best and
     *         second-best distances
     */
    @Override
    public ClusterMove<T, CentroidCluster<T>> bestClusterMove(T p) {
        ClusterMove<T, CentroidCluster<T>> result = new ClusterMove<T, CentroidCluster<T>>();
        String id = p.getId();
        result.oldCluster = this.getAssignment(id);
        if (logger.isTraceEnabled()) {
            logger.trace("Choosing best cluster for " + p + " (previous = " + result.oldCluster + ")");
        }
        for (CentroidCluster<T> c : this.getClusters()) {
            double d = this.measure.distanceFromTo(p, c.getCentroid());
            if (logger.isTraceEnabled()) {
                logger.trace("Trying " + c + "; distance = " + d + "; best so far = " + result.bestDistance);
            }
            if (d < result.bestDistance) {
                // New winner: the previous best becomes the runner-up.
                result.secondBestDistance = result.bestDistance;
                result.bestDistance = d;
                result.bestCluster = c;
            } else if (d < result.secondBestDistance) {
                result.secondBestDistance = d;
            }
        }
        if (logger.isTraceEnabled()) {
            logger.trace("Chose " + result.bestCluster);
        }
        if (result.bestCluster == null) {
            logger.warn("Can't classify: " + p);
        }
        return result;
    }
}

