/*
 * Decompiled with CFR 0.152.
 */
package com.github.pmerienne.trident.ml.nlp;

import com.github.pmerienne.trident.ml.nlp.TextClassifier;
import com.github.pmerienne.trident.ml.nlp.TextFeaturesExtractor;
import com.github.pmerienne.trident.ml.nlp.Vocabulary;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Text classifier that assigns a document to the class whose word distribution
 * is closest to the document's word distribution.
 *
 * <p>The per-word distance term is {@code (pc - pd) * log(pc / pd)}, where
 * {@code pc} is the word probability in the class and {@code pd} the word
 * probability in the document (presumably a symmetric Kullback-Leibler
 * distance, as the class name suggests). Words missing from a vocabulary get
 * the back-off probability epsilon; observed frequencies are scaled by a
 * normalisation coefficient (gamma for classes, beta for documents) that
 * removes one epsilon per missing word.
 *
 * <p>Epsilon and the gammas are cached lazily and invalidated by
 * {@link #update(Integer, List)}. No synchronisation is performed by this class.
 */
public class KLDClassifier implements TextClassifier<Integer>, TextFeaturesExtractor, Serializable {
    private static final long serialVersionUID = 3869875629653284342L;

    private static final int DEFAULT_MAX_WORDS_PER_CLASS = 500000;
    private static final double DEFAULT_THRESHOLD_FACTOR = 10.0;
    private static final boolean DEFAULT_NORMALIZE = true;

    /** Upper bound passed to {@code Vocabulary.limitWords} after each update. */
    private int maxWordsPerClass = DEFAULT_MAX_WORDS_PER_CLASS;
    /** Divisor applied when deriving epsilon from the largest class word count. */
    private double thresholdFactor = DEFAULT_THRESHOLD_FACTOR;
    /** Whether distances are divided by the distance of an empty document to the same class. */
    private boolean normalize = DEFAULT_NORMALIZE;
    /** One vocabulary per class, indexed by class index. */
    private List<Vocabulary> classVocabularies = new ArrayList<Vocabulary>();
    /** Cached gamma per class; {@code null} means "not computed yet". */
    private List<Double> gammas = new ArrayList<Double>();
    /**
     * Cached epsilon; {@code null} means "not computed yet". The misspelled name is kept
     * on purpose: it is part of the serialized form of this class.
     */
    private Double espilon = null;

    /**
     * Creates a classifier with no classes and default settings. Classes are never added
     * by this class afterwards, so {@link #update(Integer, List)} fails on such an instance.
     */
    public KLDClassifier() {
    }

    /**
     * Creates a classifier with default settings.
     *
     * @param nbClasses number of classes; valid class indexes are {@code 0..nbClasses-1}
     */
    public KLDClassifier(int nbClasses) {
        this(nbClasses, DEFAULT_MAX_WORDS_PER_CLASS, DEFAULT_THRESHOLD_FACTOR, DEFAULT_NORMALIZE);
    }

    /**
     * Creates a classifier with a custom per-class vocabulary limit.
     *
     * @param nbClasses number of classes
     * @param maxWordsPerClass limit applied to each class vocabulary after every update
     */
    public KLDClassifier(int nbClasses, int maxWordsPerClass) {
        this(nbClasses, maxWordsPerClass, DEFAULT_THRESHOLD_FACTOR, DEFAULT_NORMALIZE);
    }

    /**
     * Creates a fully configured classifier with one empty vocabulary per class.
     *
     * @param nbClasses number of classes
     * @param maxWordsPerClass limit applied to each class vocabulary after every update
     * @param thresholdFactor divisor used when estimating epsilon
     * @param normalize whether {@link #distance(List)} normalises each class distance
     */
    public KLDClassifier(int nbClasses, int maxWordsPerClass, double thresholdFactor, boolean normalize) {
        this.maxWordsPerClass = maxWordsPerClass;
        this.thresholdFactor = thresholdFactor;
        this.normalize = normalize;
        for (int i = 0; i < nbClasses; ++i) {
            this.classVocabularies.add(new Vocabulary());
            this.gammas.add(null);
        }
    }

    /**
     * Computes the per-word distance term of the document against every class.
     *
     * <p>The result has {@code vocabularySize * nbClasses} entries laid out class by class:
     * the term for word {@code i} and class {@code j} is stored at
     * {@code j * vocabularySize + i}. Word order follows the iteration order of the
     * global vocabulary set.
     *
     * @param documentWords words of the document
     * @return the feature vector described above
     */
    @Override
    public double[] extractFeatures(List<String> documentWords) {
        Vocabulary documentVocabulary = new Vocabulary(documentWords);
        Set<String> vocabulary = this.createGlobalVocabulary();
        int vocabularySize = vocabulary.size();
        int nbClasses = this.classVocabularies.size();
        double[] features = new double[vocabularySize * nbClasses];
        double beta = this.caculateBeta(documentVocabulary);
        int i = 0;
        for (String word : vocabulary) {
            double tpd = this.wordProbabilityInDocument(word, documentVocabulary, beta);
            for (int j = 0; j < nbClasses; ++j) {
                // Each class slice must be computed against its own class vocabulary.
                double tpc = this.wordProbabilityInCategory(word, j);
                features[j * vocabularySize + i] = (tpc - tpd) * Math.log(tpc / tpd);
            }
            ++i;
        }
        return features;
    }

    /**
     * Adds the words of a labelled document to the vocabulary of its class and
     * invalidates the cached coefficients.
     *
     * @param classIndex index of the class the document belongs to
     * @param documentWords words of the document
     * @throws IndexOutOfBoundsException if {@code classIndex} is not a valid class index
     */
    @Override
    public void update(Integer classIndex, List<String> documentWords) {
        Vocabulary classVocabulary = this.classVocabularies.get(classIndex);
        classVocabulary.addAll(documentWords);
        classVocabulary.limitWords(this.maxWordsPerClass);
        // Every gamma depends on epsilon and on the global vocabulary, both of which
        // may have changed, so the cache must be dropped for all classes, not only
        // for the updated one.
        for (int i = 0; i < this.gammas.size(); ++i) {
            this.gammas.set(i, null);
        }
        this.espilon = null;
    }

    /**
     * Returns the index of the class with the smallest distance to the document.
     *
     * @param documentWords words of the document
     * @return the closest class index, or {@code -1} when no class has a distance strictly
     *         below positive infinity (for instance when there are no classes or every
     *         distance is NaN)
     */
    @Override
    public Integer classify(List<String> documentWords) {
        double[] distances = this.distance(documentWords);
        int classIndex = -1;
        double minDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < distances.length; ++i) {
            if (distances[i] < minDistance) {
                minDistance = distances[i];
                classIndex = i;
            }
        }
        return classIndex;
    }

    /**
     * Computes the distance between the document and every class.
     *
     * <p>When normalisation is enabled, each distance is divided by the distance between
     * an empty document and the same class.
     *
     * @param documentWords words of the document
     * @return one distance per class, indexed by class index
     */
    public double[] distance(List<String> documentWords) {
        int nbClasses = this.classVocabularies.size();
        double[] distance = new double[nbClasses];
        Vocabulary documentVocabulary = new Vocabulary(documentWords);
        double beta = this.caculateBeta(documentVocabulary);
        double betaZero = this.caculateBeta(new Vocabulary());
        for (int classIndex = 0; classIndex < nbClasses; ++classIndex) {
            distance[classIndex] = this.distance(documentVocabulary, classIndex, beta);
            if (this.normalize) {
                distance[classIndex] /= this.distance(new Vocabulary(), classIndex, betaZero);
            }
        }
        return distance;
    }

    /**
     * Sums the per-word distance terms between a document and one class over the
     * global vocabulary.
     *
     * @param documentVocabulary vocabulary built from the document
     * @param classIndex index of the class to compare against
     * @param beta normalisation coefficient of the document
     * @return the accumulated distance
     */
    protected Double distance(Vocabulary documentVocabulary, int classIndex, double beta) {
        double distance = 0.0;
        for (String word : this.createGlobalVocabulary()) {
            double tpc = this.wordProbabilityInCategory(word, classIndex);
            double tpd = this.wordProbabilityInDocument(word, documentVocabulary, beta);
            distance += (tpc - tpd) * Math.log(tpc / tpd);
        }
        return distance;
    }

    /**
     * Returns the probability of a word in a class: epsilon when the class frequency is
     * zero or NaN, otherwise the frequency scaled by the class gamma.
     *
     * @param word word to look up
     * @param classIndex index of the class
     * @return the smoothed probability
     */
    protected Double wordProbabilityInCategory(String word, int classIndex) {
        Vocabulary classVocabulary = this.classVocabularies.get(classIndex);
        double frequency = classVocabulary.frequency(word);
        if (frequency == 0.0 || Double.isNaN(frequency)) {
            // Use the cached value: re-estimating epsilon scans every class vocabulary.
            return this.getEpsilon();
        }
        return frequency * this.getGamma(classIndex);
    }

    /**
     * Returns the probability of a word in a document: epsilon when the document
     * frequency is zero or NaN, otherwise the frequency scaled by {@code beta}.
     *
     * @param word word to look up
     * @param documentVocabulary vocabulary built from the document
     * @param beta normalisation coefficient of the document
     * @return the smoothed probability
     */
    protected Double wordProbabilityInDocument(String word, Vocabulary documentVocabulary, double beta) {
        double frequency = documentVocabulary.frequency(word);
        if (frequency == 0.0 || Double.isNaN(frequency)) {
            return this.getEpsilon();
        }
        return frequency * beta;
    }

    /**
     * Returns epsilon, computing and caching it on first use after an update.
     *
     * @return the back-off probability for unseen words
     */
    protected double getEpsilon() {
        if (this.espilon == null) {
            this.espilon = this.estimateEpsilon();
        }
        return this.espilon;
    }

    /**
     * Estimates epsilon as {@code 1 / (thresholdFactor * maxCount)}, where {@code maxCount}
     * is the largest {@code totalCount()} among the class vocabularies.
     *
     * <p>When every class is empty {@code maxCount} is zero and the division is by zero,
     * so the result is not a finite number.
     *
     * @return the estimated epsilon
     */
    protected double estimateEpsilon() {
        int maxSize = 0;
        for (Vocabulary vocabulary : this.classVocabularies) {
            int candidate = vocabulary.totalCount();
            if (candidate > maxSize) {
                maxSize = candidate;
            }
        }
        return 1.0 / (this.thresholdFactor * maxSize);
    }

    /**
     * Returns the gamma coefficient of a class, computing and caching it on first use
     * after an update.
     *
     * @param classIndex index of the class
     * @return the class normalisation coefficient
     */
    protected double getGamma(int classIndex) {
        Double gamma = this.gammas.get(classIndex);
        if (gamma == null) {
            gamma = this.calculateGamma(classIndex);
            this.gammas.set(classIndex, gamma);
        }
        return gamma;
    }

    /**
     * Computes gamma for a class: starting from 1, subtracts epsilon once for every
     * global-vocabulary word the class vocabulary does not contain.
     *
     * @param classIndex index of the class
     * @return the class normalisation coefficient
     */
    protected double calculateGamma(int classIndex) {
        double gamma = 1.0;
        double epsilon = this.getEpsilon();
        Vocabulary classVocabulary = this.classVocabularies.get(classIndex);
        for (String word : this.createGlobalVocabulary()) {
            if (!classVocabulary.contains(word).booleanValue()) {
                gamma -= epsilon;
            }
        }
        return gamma;
    }

    /**
     * Computes beta for a document: starting from 1, subtracts epsilon once for every
     * global-vocabulary word the document vocabulary does not contain.
     *
     * <p>The misspelled method name is kept because subclasses may override it.
     *
     * @param documentVocabulary vocabulary built from the document
     * @return the document normalisation coefficient
     */
    protected double caculateBeta(Vocabulary documentVocabulary) {
        double beta = 1.0;
        double epsilon = this.getEpsilon();
        for (String word : this.createGlobalVocabulary()) {
            if (!documentVocabulary.contains(word).booleanValue()) {
                beta -= epsilon;
            }
        }
        return beta;
    }

    /**
     * Builds the union of the word sets of all class vocabularies.
     *
     * @return a new set holding every word known to at least one class
     */
    private Set<String> createGlobalVocabulary() {
        Set<String> vocabulary = new HashSet<String>();
        for (Vocabulary classVocabulary : this.classVocabularies) {
            vocabulary.addAll(classVocabulary.wordSet());
        }
        return vocabulary;
    }
}

