/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import mikera.vectorz.AVector;
import mikera.vectorz.ArrayVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.Vector0;
import nuroko.module.AStateComponent;

/**
 * Pass-through component that adds a sparsity term to the input gradient.
 *
 * <p>During thinking the input is copied unchanged to the output. During
 * training a running mean of each input element is maintained, and a
 * per-element term pulling that mean towards {@code targetMean} is written
 * to the input gradient before the output gradient is added on top.
 *
 * <p>The component exposes no trainable parameters: both
 * {@link #getParameters()} and {@link #getGradient()} return
 * {@code Vector0.INSTANCE}.
 */
public class Sparsifier
extends AStateComponent {
    /** Weight given to the newest input when updating the running mean. */
    private static final double MEAN_RATE = 0.001;
    /** Multiplier applied to the sparsity term of the input gradient. */
    private final double weight;
    /** Desired long-run mean value of each input element. */
    private final double targetMean;
    /** Running mean of each input element, updated on every training step. */
    private final Vector mean;

    /**
     * Creates a sparsifier whose input and output both have the given length.
     * The running mean of every element starts at {@code targetMean}.
     *
     * @param length     number of input (and output) elements
     * @param targetMean desired mean of each element; the gradient formula
     *                   divides by the running mean and by one minus the
     *                   running mean, so values of exactly 0 or 1 yield
     *                   non-finite gradients
     * @param weight     multiplier applied to the sparsity gradient term
     */
    public Sparsifier(int length, double targetMean, double weight) {
        super(length, length);
        this.targetMean = targetMean;
        this.mean = Vector.createLength(length);
        this.mean.fill(targetMean);
        this.weight = weight;
    }

    /** Copies the input to the output unchanged. */
    @Override
    public void thinkInternal() {
        this.output.set((AVector)this.input);
    }

    /**
     * @return {@code false}; this component always reports that thinking
     *         does not differ between training and normal operation
     */
    @Override
    public boolean hasDifferentTrainingThinking() {
        return false;
    }

    /** @return {@code Vector0.INSTANCE}; this component has no parameters */
    @Override
    public AVector getParameters() {
        return Vector0.INSTANCE;
    }

    /** @return {@code Vector0.INSTANCE}; this component has no parameter gradient */
    @Override
    public AVector getGradient() {
        return Vector0.INSTANCE;
    }

    /**
     * Updates the running mean from the current input and computes the
     * input gradient.
     *
     * <p>The running mean is updated as
     * {@code mean = (1 - MEAN_RATE) * mean + MEAN_RATE * input}. Each input
     * gradient element is then set to
     * {@code weight * (targetMean / mean_i - (1 - targetMean) / (1 - mean_i))}
     * and the output gradient is added to the result.
     *
     * <p>NOTE(review): the per-element term has the form of the negated
     * derivative of a KL-divergence sparsity penalty with respect to the
     * mean; the sign convention presumably matches how the framework applies
     * gradients - confirm against callers.
     *
     * <p>NOTE(review): no guard exists for a running mean of exactly 0 or 1;
     * in that case the divisions produce infinite or NaN gradient values.
     *
     * @param factor not used by this implementation
     */
    @Override
    public void trainGradientInternal(double factor) {
        // Exponential moving average; both factors derive from MEAN_RATE so
        // they cannot drift apart when the rate is tuned.
        this.mean.multiply(1.0 - MEAN_RATE);
        this.mean.addMultiple((ArrayVector)this.input, MEAN_RATE);

        // Loop-invariant terms, hoisted without reordering the arithmetic.
        final double target = this.targetMean;
        final double targetComplement = 1.0 - target;
        final double scale = this.weight;

        int n = this.getInputLength();
        for (int i = 0; i < n; ++i) {
            double mi = this.mean.get(i);
            this.inputGradient.set(i, scale * (target / mi - targetComplement / (1.0 - mi)));
        }
        // Pass the downstream gradient through on top of the sparsity term.
        this.inputGradient.add((ArrayVector)this.outputGradient);
    }

    /**
     * Creates a new sparsifier with the same length, target mean and weight,
     * and copies the current running mean into it.
     *
     * @return an independent copy carrying this instance's running mean
     */
    @Override
    public Sparsifier clone() {
        Sparsifier s = new Sparsifier(this.getInputLength(), this.targetMean, this.weight);
        s.mean.set((AVector)this.mean);
        return s;
    }
}

