/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.features;

import com.aliasi.classify.Classified;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.features.FeatureExtractorFilter;
import com.aliasi.stats.OnlineNormalEstimator;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.FeatureExtractor;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

/**
 * A {@link FeatureExtractorFilter} that rescales the output of a base
 * feature extractor into z-scores, {@code (value - mean) / deviation},
 * using per-feature means and standard deviations estimated from the
 * training items of a corpus.
 *
 * <p>When built from a corpus, only features that were produced by the base
 * extractor during training and whose estimated standard deviation is
 * strictly positive are kept; these are the {@link #knownFeatures() known
 * features}. {@link #features(Object)} reports a z-score for every known
 * feature, treating a feature absent from the base extractor's output as
 * having value {@code 0.0}. Features the base extractor produces that are
 * not known are dropped from the output.
 *
 * <p>Serialization is delegated through {@code writeReplace()} to
 * {@link Serializer}, which writes the base extractor followed by the
 * per-feature statistics.
 *
 * @param <E> the type of object from which features are extracted
 */
public class ZScoreFeatureExtractor<E>
extends FeatureExtractorFilter<E>
implements Serializable {
    static final long serialVersionUID = -5628628145432035433L;

    /**
     * Known features mapped to their mean and standard deviation. Iteration
     * follows insertion order, which is the order used by {@link #toString()}
     * and by serialization.
     */
    final Map<String, MeanDev> mFeatureToMeanDev;

    /**
     * Constructs a z-score extractor from precomputed statistics. The map is
     * copied, so later changes to the argument do not affect this extractor.
     *
     * @param extractor base feature extractor to filter
     * @param featureToMeanDev statistics for each known feature
     */
    ZScoreFeatureExtractor(FeatureExtractor<? super E> extractor, Map<String, MeanDev> featureToMeanDev) {
        super(extractor);
        this.mFeatureToMeanDev = new LinkedHashMap<String, MeanDev>(featureToMeanDev);
    }

    /**
     * Constructs a z-score extractor whose statistics are estimated from the
     * training items of the specified corpus. The training data is visited
     * twice. The first visit collects every feature name the base extractor
     * produces. The second accumulates each feature's mean and standard
     * deviation, counting an absent feature as {@code 0.0}.
     *
     * @param corpus corpus whose training items supply the statistics
     * @param extractor base feature extractor to filter
     * @throws IOException if visiting the corpus throws one
     */
    public ZScoreFeatureExtractor(Corpus<ObjectHandler<Classified<E>>> corpus, FeatureExtractor<? super E> extractor) throws IOException {
        this(extractor, ZScoreFeatureExtractor.meanDevs(corpus, extractor));
    }

    /**
     * Returns a map from every known feature to its z-score for the specified
     * input. A known feature missing from the base extractor's output is
     * scored as value {@code 0.0}.
     *
     * @param in object whose features are extracted
     * @return known features mapped to their z-scores
     */
    @Override
    public Map<String, ? extends Number> features(E in) {
        Map<String, ? extends Number> featureMap = super.features(in);
        Map<String, Double> result = new HashMap<String, Double>();
        for (Map.Entry<String, MeanDev> featMeanDev : this.mFeatureToMeanDev.entrySet()) {
            String feature = featMeanDev.getKey();
            Number n = featureMap.get(feature);
            double rawValue = (n == null) ? 0.0 : n.doubleValue();
            result.put(feature, featMeanDev.getValue().zScore(rawValue));
        }
        return result;
    }

    /**
     * Returns the z-score of the specified value for the specified feature,
     * or {@link Double#NaN} if the feature is not known.
     *
     * @param feature feature name
     * @param value raw feature value
     * @return z-score of the value, or {@code NaN} for an unknown feature
     */
    public double zScore(String feature, double value) {
        MeanDev meanDev = this.mFeatureToMeanDev.get(feature);
        // A boxed null cannot be returned as a double (unboxing would throw
        // NullPointerException); NaN matches mean() and standardDeviation().
        return meanDev == null ? Double.NaN : meanDev.zScore(value);
    }

    /**
     * Returns the mean of the specified feature, or {@link Double#NaN} if the
     * feature is not known.
     *
     * @param feature feature name
     * @return the feature's mean, or {@code NaN}
     */
    public double mean(String feature) {
        MeanDev meanDev = this.mFeatureToMeanDev.get(feature);
        return meanDev == null ? Double.NaN : meanDev.mMean;
    }

    /**
     * Returns the standard deviation of the specified feature, or
     * {@link Double#NaN} if the feature is not known.
     *
     * @param feature feature name
     * @return the feature's standard deviation, or {@code NaN}
     */
    public double standardDeviation(String feature) {
        MeanDev meanDev = this.mFeatureToMeanDev.get(feature);
        return meanDev == null ? Double.NaN : meanDev.mDev;
    }

    /**
     * Returns an unmodifiable view of the known feature names.
     *
     * @return the known features
     */
    public Set<String> knownFeatures() {
        return Collections.unmodifiableSet(this.mFeatureToMeanDev.keySet());
    }

    /**
     * Returns one line per known feature, formatted as
     * {@code |feature| mean=M dev=D}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, MeanDev> entry : this.mFeatureToMeanDev.entrySet()) {
            sb.append("|");
            sb.append(entry.getKey());
            sb.append("| ");
            sb.append(entry.getValue());
            sb.append('\n');
        }
        return sb.toString();
    }

    /** Replaces this object with a {@link Serializer} during Java serialization. */
    Object writeReplace() {
        return new Serializer<E>(this);
    }

    /**
     * Estimates the mean and standard deviation of every feature the extractor
     * produces over the corpus's training items. Features whose deviation is
     * not strictly positive (zero or NaN) are omitted because they cannot be
     * standardized.
     *
     * @param corpus corpus whose training items are visited (twice)
     * @param extractor extractor applied to each training object
     * @return features mapped to their estimated statistics
     * @throws IOException if visiting the corpus throws one
     */
    static <F> Map<String, MeanDev> meanDevs(Corpus<ObjectHandler<Classified<F>>> corpus, final FeatureExtractor<? super F> extractor) throws IOException {
        // Pass 1: the full feature vocabulary is needed before any statistics
        // can be computed, because an absent feature counts as 0.0.
        final Set<String> collectedFeatures = new HashSet<String>();
        corpus.visitTrain(new ObjectHandler<Classified<F>>() {

            @Override
            public void handle(Classified<F> classified) {
                collectedFeatures.addAll(extractor.features(classified.getObject()).keySet());
            }
        });

        // Pass 2: accumulate one online estimator per feature.
        final Map<String, OnlineNormalEstimator> featToEstimator = new HashMap<String, OnlineNormalEstimator>();
        corpus.visitTrain(new ObjectHandler<Classified<F>>() {

            @Override
            public void handle(Classified<F> classified) {
                // Extract once per item rather than once per feature.
                Map<String, ? extends Number> featureMap = extractor.features(classified.getObject());
                for (String feature : collectedFeatures) {
                    Number value = featureMap.get(feature);
                    double v = value == null ? 0.0 : value.doubleValue();
                    OnlineNormalEstimator estimator = featToEstimator.get(feature);
                    if (estimator == null) {
                        estimator = new OnlineNormalEstimator();
                        featToEstimator.put(feature, estimator);
                    }
                    estimator.handle(v);
                }
            }
        });

        Map<String, MeanDev> result = new HashMap<String, MeanDev>();
        for (Map.Entry<String, OnlineNormalEstimator> entry : featToEstimator.entrySet()) {
            OnlineNormalEstimator estimator = entry.getValue();
            double mean = estimator.mean();
            double dev = estimator.standardDeviation();
            // !(dev > 0.0) also rejects NaN, not just zero.
            if (!(dev > 0.0)) {
                continue;
            }
            result.put(entry.getKey(), new MeanDev(mean, dev));
        }
        return result;
    }

    /**
     * Externalizable serialization proxy. The stream layout is the base
     * extractor object, an {@code int} feature count, then for each feature
     * its UTF name, mean ({@code double}) and deviation ({@code double}).
     *
     * @param <F> the filter's input type
     */
    static class Serializer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = 6365515337527915147L;
        private final ZScoreFeatureExtractor<F> mFilter;

        /** No-argument constructor required for externalizable deserialization. */
        public Serializer() {
            this(null);
        }

        /**
         * Constructs a serializer for the specified filter.
         *
         * @param filter filter to serialize
         */
        public Serializer(ZScoreFeatureExtractor<F> filter) {
            this.mFilter = filter;
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject(this.mFilter.baseExtractor());
            out.writeInt(this.mFilter.mFeatureToMeanDev.size());
            for (Map.Entry<String, MeanDev> entry : this.mFilter.mFeatureToMeanDev.entrySet()) {
                out.writeUTF(entry.getKey());
                out.writeDouble(entry.getValue().mMean);
                out.writeDouble(entry.getValue().mDev);
            }
        }

        @Override
        public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
            // writeExternal stored this filter's base extractor, so the cast
            // restores the type it was written with.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super F> extractor = (FeatureExtractor<? super F>) in.readObject();
            int numFeats = in.readInt();
            if (numFeats < 0) {
                throw new IOException("Negative feature count in stream: " + numFeats);
            }
            // LinkedHashMap keeps the features in the order they were written.
            Map<String, MeanDev> featureToMeanDev = new LinkedHashMap<String, MeanDev>(numFeats + numFeats / 2 + 1);
            for (int i = 0; i < numFeats; ++i) {
                String feature = in.readUTF();
                double mean = in.readDouble();
                double dev = in.readDouble();
                featureToMeanDev.put(feature, new MeanDev(mean, dev));
            }
            return new ZScoreFeatureExtractor<F>(extractor, featureToMeanDev);
        }
    }

    /** Immutable mean and standard deviation for one feature. */
    static final class MeanDev {
        final double mMean;
        final double mDev;

        MeanDev(double mean, double dev) {
            this.mMean = mean;
            this.mDev = dev;
        }

        /** Returns {@code (x - mean) / dev}. */
        double zScore(double x) {
            return (x - this.mMean) / this.mDev;
        }

        @Override
        public String toString() {
            return "mean=" + this.mMean + " dev=" + this.mDev;
        }
    }
}

