/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import mikera.vectorz.AVector;
import mikera.vectorz.ArrayVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.Vector0;
import nuroko.module.AStateComponent;

/**
 * Identity component that adds a sparsity penalty to the gradient flowing into its inputs.
 *
 * <p>During thinking the input is copied unchanged to the output. During training it keeps an
 * exponential moving average ({@code mean}) of every input unit. For each unit it writes a penalty
 * term into the input gradient, based on how far that running average is from
 * {@code targetMean}. The output gradient is then added on top, so downstream gradients pass
 * through unchanged.
 *
 * <p>For unit {@code i} the penalty term is
 * {@code weight * STANDARD_WEIGHT_FACTOR * (targetMean / m_i - (1 - targetMean) / (1 - m_i))}.
 * Here {@code m_i} is the running mean clamped to {@code [LIMIT_THRESHOLD, 1 - LIMIT_THRESHOLD]},
 * which keeps both denominators non-zero. The expression equals the negated derivative of the
 * Bernoulli KL divergence {@code KL(targetMean || m_i)} with respect to {@code m_i}.
 * NOTE(review): this presumes inputs lie in [0, 1] and that nuroko's input gradient holds the
 * direction of improvement — confirm against AStateComponent.
 *
 * <p>The component has no trainable parameters of its own.
 */
public class Sparsifier
extends AStateComponent {
    /** Default smoothing rate for the running mean of each input unit. */
    private static final double MEAN_RATE = 0.001;
    /** Lower clamp bound for running means; keeps the penalty finite. */
    private static final double LIMIT_THRESHOLD = 0.001;
    /** Upper clamp bound for running means (symmetric to {@link #LIMIT_THRESHOLD}). */
    private static final double UPPER_LIMIT = 1.0 - LIMIT_THRESHOLD;
    /** Fixed scale applied to {@code weight} when computing the penalty gradient. */
    private static final double STANDARD_WEIGHT_FACTOR = 0.1;

    private final double weight;
    private final double targetMean;
    /** Per-unit exponential moving average of the input, initialised to {@code targetMean}. */
    private final Vector mean;
    private double meanLearnRate = MEAN_RATE;

    /**
     * Creates a sparsifier using the default mean learning rate.
     *
     * @param length number of units; passed as both arguments to the {@code AStateComponent}
     *     constructor (presumably input and output length)
     * @param targetMean desired average activation of each unit; also the initial running mean
     * @param weight strength of the sparsity penalty (scaled by {@code STANDARD_WEIGHT_FACTOR})
     */
    public Sparsifier(int length, double targetMean, double weight) {
        super(length, length);
        this.targetMean = targetMean;
        this.mean = Vector.createLength(length);
        this.mean.fill(targetMean);
        this.weight = weight;
    }

    /**
     * Creates a sparsifier with an explicit mean learning rate.
     *
     * @param length number of units
     * @param targetMean desired average activation of each unit
     * @param weight strength of the sparsity penalty
     * @param meanLearnRate smoothing rate of the running mean; each training step computes
     *     {@code mean = (1 - meanLearnRate) * mean + meanLearnRate * input}
     */
    public Sparsifier(int length, double targetMean, double weight, double meanLearnRate) {
        this(length, targetMean, weight);
        this.meanLearnRate = meanLearnRate;
    }

    /** Copies the input to the output unchanged. */
    @Override
    public void thinkInternal() {
        this.output.set((AVector)this.input);
    }

    /** Returns {@code false}: training-time thinking is identical to normal thinking. */
    @Override
    public boolean hasDifferentTrainingThinking() {
        return false;
    }

    /** Returns an empty vector: this component has no trainable parameters. */
    @Override
    public AVector getParameters() {
        return Vector0.INSTANCE;
    }

    /** Returns an empty vector, matching the empty parameter vector. */
    @Override
    public AVector getGradient() {
        return Vector0.INSTANCE;
    }

    /**
     * Updates the running means and writes the sparsity penalty plus the output gradient into
     * {@code inputGradient}.
     *
     * @param factor not used by this component
     */
    @Override
    public void trainGradientInternal(double factor) {
        // Exponential moving average: mean = (1 - rate) * mean + rate * input.
        this.mean.multiply(1.0 - this.meanLearnRate);
        this.mean.addMultiple((ArrayVector)this.input, this.meanLearnRate);

        final double scaledWeight = this.weight * STANDARD_WEIGHT_FACTOR;
        final double offTargetMean = 1.0 - this.targetMean;
        final int n = this.getInputLength();
        for (int i = 0; i < n; ++i) {
            double mi = clampMean(this.mean.get(i));
            this.inputGradient.set(i, scaledWeight * (this.targetMean / mi - offTargetMean / (1.0 - mi)));
        }
        // Pass the downstream gradient straight through on top of the penalty.
        this.inputGradient.add((ArrayVector)this.outputGradient);
    }

    /** Clamps a running mean into {@code [LIMIT_THRESHOLD, UPPER_LIMIT]} so divisions stay finite. */
    private static double clampMean(double m) {
        return Math.min(UPPER_LIMIT, Math.max(LIMIT_THRESHOLD, m));
    }

    /** Returns an independent copy, including the current running means and learning rate. */
    @Override
    public Sparsifier clone() {
        Sparsifier s = new Sparsifier(this.getInputLength(), this.targetMean, this.weight, this.meanLearnRate);
        s.mean.set((AVector)this.mean);
        return s;
    }
}

