/*
 * Decompiled with CFR 0.152.
 */
package com.davidsoergel.stats;

import com.davidsoergel.dsutils.DSArrayUtils;
import com.davidsoergel.dsutils.math.MathUtils;
import com.davidsoergel.stats.DistributionException;
import com.davidsoergel.stats.MultinomialDistribution;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Multiset;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.commons.collections15.Bag;
import org.apache.commons.lang.ArrayUtils;
import org.jetbrains.annotations.NotNull;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A discrete probability distribution over arbitrary elements of type {@code T}.
 * <p>
 * Each element is mapped to a distinct integer index, handed out in insertion order starting at 0; the
 * element's probability is stored at that index of the backing {@link MultinomialDistribution}'s
 * {@code probs} array. {@link #remove} shifts later indexes down by one so the indexes stay contiguous.
 * {@link #put} and {@link #increment} do not renormalize; call {@link #normalize()} when needed.
 * <p>
 * Every public instance method synchronizes on this instance.
 *
 * @param <T> the element type
 */
public class Multinomial<T>
implements Cloneable {
    /** Backing probability vector; {@code probs[i]} belongs to the element mapped to index {@code i}. */
    private MultinomialDistribution dist = new MultinomialDistribution();
    /** Element to index into {@code dist.probs}. */
    private BiMap<T, Integer> elementIndexes = HashBiMap.create(10);
    /** Next index to hand out; kept equal to the number of elements. */
    private int maxIndex = 0;

    /**
     * Builds the pointwise mixture {@code (1 - mixingProportion) * basis + mixingProportion * bias} over the
     * elements of {@code basis}.
     * <p>
     * With assertions enabled, both inputs and the result must already be normalized and both inputs must
     * have the same number of elements.
     *
     * @param basis            distribution weighted by {@code 1 - mixingProportion}
     * @param bias             distribution weighted by {@code mixingProportion}
     * @param mixingProportion weight given to {@code bias}
     * @return a new multinomial holding the mixed probabilities
     * @throws DistributionException if {@code bias} has no probability for an element of {@code basis}
     */
    public static <T> Multinomial<T> mixture(Multinomial<T> basis, Multinomial<T> bias, double mixingProportion) throws DistributionException {
        Multinomial<T> result = new Multinomial<T>();
        assert (basis.isAlreadyNormalized());
        assert (bias.isAlreadyNormalized());
        assert (basis.getElements().size() == bias.getElements().size());
        for (T key : basis.getElements()) {
            double p = (1.0 - mixingProportion) * basis.get(key) + mixingProportion * bias.get(key);
            result.put(key, p);
        }
        assert (result.isAlreadyNormalized());
        return result;
    }

    /**
     * Reports whether the backing distribution considers itself normalized.
     *
     * @return the result of {@code MultinomialDistribution.isAlreadyNormalized()}
     * @throws DistributionException propagated from the backing distribution
     */
    public synchronized boolean isAlreadyNormalized() throws DistributionException {
        return this.dist.isAlreadyNormalized();
    }

    /** Creates an empty multinomial. */
    public Multinomial() {
    }

    /**
     * Creates a multinomial over {@code keys}, taking each key's weight from {@code values}, then normalizes.
     *
     * @param keys   the elements, in index order
     * @param values weight for each key
     * @throws DistributionException if a key has no (or a null) weight in {@code values}, or normalization fails
     */
    public Multinomial(T[] keys, Map<T, Double> values) throws DistributionException {
        this();
        for (T k : keys) {
            Double value = values.get(k);
            if (value == null) {
                // Previously this NPE'd while unboxing; report it as a distribution problem instead.
                throw new DistributionException("No value given for key: " + k);
            }
            this.put(k, value);
        }
        this.normalize();
    }

    /**
     * Creates a multinomial proportional to the counts in {@code counts}, then normalizes.
     *
     * @param counts a commons-collections bag of observations
     * @throws DistributionException if normalization fails
     */
    public Multinomial(Bag<T> counts) throws DistributionException {
        this();
        for (T k : counts.uniqueSet()) {
            this.put(k, counts.getCount(k));
        }
        this.normalize();
    }

    /**
     * Creates a multinomial proportional to the counts in {@code counts}, then normalizes.
     *
     * @param counts a Guava multiset of observations
     * @throws DistributionException if normalization fails
     */
    public Multinomial(Multiset<T> counts) throws DistributionException {
        this();
        for (Multiset.Entry<T> k : counts.entrySet()) {
            this.put(k.getElement(), k.getCount());
        }
        this.normalize();
    }

    /**
     * Sets the (unnormalized) value for {@code obj}, adding it as a new element if absent.
     *
     * @param obj  the element
     * @param prob the value to store
     * @throws DistributionException propagated from the backing distribution
     */
    public synchronized void put(@NotNull T obj, double prob) throws DistributionException {
        Integer index = this.elementIndexes.get(obj);
        if (index == null) {
            this.appendElement(obj, prob);
        } else {
            this.dist.update(index, prob);
        }
    }

    /**
     * Appends a new element with value {@code prob} at index {@code maxIndex}. Callers must hold this
     * instance's lock and must have checked that {@code obj} is absent.
     */
    private void appendElement(T obj, double prob) throws DistributionException {
        // Extend the probability vector first, so a failure there leaves the index map untouched.
        this.dist.add(prob);
        this.elementIndexes.put(obj, this.maxIndex);
        ++this.maxIndex;
    }

    /**
     * Normalizes the backing distribution.
     *
     * @throws DistributionException propagated from {@code MultinomialDistribution.normalize()}
     */
    public synchronized void normalize() throws DistributionException {
        this.dist.normalize();
    }

    /**
     * Returns the elements of this multinomial.
     *
     * @return a read-only live view of the element set; it cannot be used to remove elements, because that
     *         would desynchronize the index map from the probability vector
     */
    public synchronized Collection<T> getElements() {
        return Collections.unmodifiableSet(this.elementIndexes.keySet());
    }

    /**
     * Returns an independent copy: the backing distribution, the index map and the index counter are all copied.
     *
     * @return the copy
     */
    @Override
    public synchronized Multinomial<T> clone() {
        Multinomial<T> result = new Multinomial<T>();
        result.dist = new MultinomialDistribution(this.dist);
        result.elementIndexes = HashBiMap.create(this.elementIndexes);
        // Without this the clone would reissue index 0 and the BiMap would reject the next new element.
        result.maxIndex = this.maxIndex;
        return result;
    }

    /**
     * Computes {@code sum over e of p(e) * log(p(e) / q(e))}, summing over this multinomial's elements, where
     * {@code p} is this multinomial, {@code q} is {@code belief}, and the log is {@code MathUtils.approximateLog}.
     * <p>
     * NOTE(review): this holds this instance's lock while calling {@code belief.get}, which locks
     * {@code belief}; two threads calling this in opposite directions could deadlock.
     *
     * @param belief the distribution {@code q}
     * @return the divergence
     * @throws DistributionException if any {@code p} or {@code q} is zero, if the sum becomes NaN, or if
     *                               {@code belief} lacks one of this multinomial's elements
     */
    public synchronized double KLDivergenceToThisFrom(Multinomial<T> belief) throws DistributionException {
        double divergence = 0.0;
        for (T key : this.elementIndexes.keySet()) {
            double p = this.get(key);
            double q = belief.get(key);
            if (p == 0.0 || q == 0.0) {
                throw new DistributionException("Can't compute KL divergence: distributions not smoothed");
            }
            divergence += p * MathUtils.approximateLog(p / q);
            if (Double.isNaN(divergence)) {
                throw new DistributionException("Got NaN when computing KL divergence.");
            }
        }
        return divergence;
    }

    /**
     * Returns {@code MathUtils.approximateLog(get(obj))}.
     *
     * @throws DistributionException if {@code obj} is not an element
     */
    public synchronized double getLog(T obj) throws DistributionException {
        return MathUtils.approximateLog(this.get(obj));
    }

    /**
     * Returns the stored value for {@code obj}.
     *
     * @throws DistributionException if {@code obj} is not an element
     */
    public synchronized double get(T obj) throws DistributionException {
        Integer i = this.elementIndexes.get(obj);
        if (i == null) {
            throw new DistributionException("No probability known: " + obj);
        }
        return this.dist.probs[i];
    }

    /**
     * Replaces each element's value {@code p} with {@code p * (1 - smoothFactor) + uniform.get(e) * smoothFactor}.
     * <p>
     * NOTE(review): this locks {@code uniform} while holding this instance's lock (same deadlock caveat
     * as {@link #KLDivergenceToThisFrom}).
     *
     * @param uniform      distribution to blend in; must contain every element of this one
     * @param smoothFactor weight given to {@code uniform}
     * @throws DistributionException if {@code uniform} lacks an element
     */
    public synchronized void mixIn(Multinomial<T> uniform, double smoothFactor) throws DistributionException {
        BiMap<Integer, T> byIndex = this.elementIndexes.inverse();
        for (int c = 0; c < this.elementIndexes.size(); ++c) {
            this.dist.probs[c] = this.dist.probs[c] * (1.0 - smoothFactor) + uniform.get(byIndex.get(c)) * smoothFactor;
        }
    }

    /**
     * Draws an element using {@code MultinomialDistribution.sample()}.
     *
     * @return the element at the sampled index
     * @throws DistributionException propagated from the backing distribution
     * @throws IllegalStateException if the sampled index has no element (broken index invariant)
     */
    @NotNull
    public synchronized T sample() throws DistributionException {
        int index = this.dist.sample();
        T result = this.elementIndexes.inverse().get(index);
        if (result == null) {
            throw new IllegalStateException("Impossible");
        }
        return result;
    }

    /** Returns the number of elements. */
    public synchronized int size() {
        return this.elementIndexes.size();
    }

    /**
     * Rescales every value to {@code (1 - n * minimumProbability) * p + minimumProbability}, where {@code n}
     * is the element count. If the values summed to 1 beforehand, they still do, and each is at least
     * {@code minimumProbability}.
     *
     * @param minimumProbability floor added to every element
     * @throws DistributionException if {@code n * minimumProbability > 1}
     */
    public synchronized void redistributeWithMinimum(double minimumProbability) throws DistributionException {
        int n = this.maxIndex;
        double redistributionProportion = (double) n * minimumProbability;
        if (redistributionProportion > 1.0) {
            throw new DistributionException("Can't use a minimum probability of " + minimumProbability + " for a multinomial with " + n + " elements.");
        }
        for (int c = 0; c < n; ++c) {
            this.dist.probs[c] = (1.0 - redistributionProportion) * this.dist.probs[c] + minimumProbability;
        }
    }

    /** Returns the largest stored value, via {@code DSArrayUtils.max}. */
    public synchronized double getDominantProbability() {
        return DSArrayUtils.max(this.dist.probs);
    }

    /** Returns the element at the index chosen by {@code DSArrayUtils.argmax} over the stored values. */
    public synchronized T getDominantKey() {
        return this.elementIndexes.inverse().get(DSArrayUtils.argmax(this.dist.probs));
    }

    /**
     * Removes {@code obj} and renormalizes the remaining values.
     * <p>
     * Later elements are shifted down one index so the indexes stay contiguous.
     *
     * @throws DistributionException if {@code obj} is not an element, or if renormalization fails
     */
    public synchronized void remove(T obj) throws DistributionException {
        Integer removedIndex = this.elementIndexes.get(obj);
        if (removedIndex == null) {
            throw new DistributionException("Can't remove nonexistent element: " + obj);
        }
        this.elementIndexes.remove(obj);
        this.dist.probs = ArrayUtils.remove(this.dist.probs, removedIndex.intValue());
        // Shift later elements down in ascending order: each target index (i - 1) was freed in the previous step.
        BiMap<Integer, T> byIndex = this.elementIndexes.inverse();
        for (int i = removedIndex + 1; i < this.maxIndex; ++i) {
            T shifted = byIndex.get(i);
            this.elementIndexes.put(shifted, i - 1);
        }
        // Keep maxIndex equal to the element count, so the next appended element lands in the right slot.
        --this.maxIndex;
        // Normalize last, so a failure here still leaves the indexes consistent.
        this.dist.normalize();
    }

    /**
     * Adds {@code increment} to the stored value for {@code obj}, adding {@code obj} with value
     * {@code increment} if it is absent.
     *
     * @throws DistributionException propagated from the backing distribution
     */
    public synchronized void increment(T obj, double increment) throws DistributionException {
        Integer index = this.elementIndexes.get(obj);
        if (index == null) {
            this.appendElement(obj, increment);
        } else {
            this.dist.update(index, this.dist.probs[index] + increment);
        }
    }

    /**
     * Returns a snapshot mapping each element to its stored value.
     *
     * @return a new mutable {@link HashMap}; later changes to this multinomial are not reflected in it
     */
    public synchronized Map<T, Double> getValueMap() {
        Map<T, Double> result = new HashMap<T, Double>();
        for (Map.Entry<T, Integer> entry : this.elementIndexes.entrySet()) {
            result.put(entry.getKey(), this.dist.probs[entry.getValue()]);
        }
        return result;
    }
}

