/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import java.util.Collections;
import java.util.List;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.ADenseArrayVector;
import mikera.vectorz.impl.Vector0;
import nuroko.core.IComponent;
import nuroko.module.AStateComponent;

/**
 * Element-wise normalisation component: maps each input element {@code x[i]}
 * to {@code (x[i] - mean[i]) / stdev[i]}.
 *
 * <p>A standard deviation of exactly {@code 0.0} is treated as {@code 1.0} (no
 * scaling) in the forward pass, the backward pass and {@link #generate}. A
 * constant feature is therefore only shifted by its mean, rather than being
 * turned into infinities or NaNs.
 *
 * <p>The component has no trainable parameters: {@link #getParameters()} and
 * {@link #getGradient()} both return an empty vector.
 */
public class Normaliser extends AStateComponent {
    /** Per-element offset subtracted from the input. */
    private final Vector mean;
    /** Per-element divisor; an entry of exactly 0.0 is treated as 1.0. */
    private final Vector stdev;

    /**
     * Creates a normaliser with equal input and output length. The mean vector
     * is left as created by {@code Vector.createLength} and the stdev vector is
     * filled with 1.0. The public factories then overwrite both.
     *
     * @param length input (and output) length
     */
    private Normaliser(int length) {
        super(length, length);
        this.mean = Vector.createLength(length);
        this.stdev = Vector.createLength(length);
        this.stdev.fill(1.0);
    }

    /**
     * Creates a normaliser with a per-element mean and standard deviation.
     *
     * @param mean  per-element value subtracted from the input; its length
     *              determines the component's length
     * @param stdev per-element divisor; must be the same length as {@code mean}
     * @return a new normaliser holding copies of the given values
     * @throws IllegalArgumentException if {@code mean} and {@code stdev} differ in length
     */
    public static Normaliser create(AVector mean, AVector stdev) {
        if (stdev.length() != mean.length()) {
            throw new IllegalArgumentException("mean and stdev must have the same length: "
                    + mean.length() + " != " + stdev.length());
        }
        Normaliser n = new Normaliser(mean.length());
        n.mean.set(mean);
        n.stdev.set(stdev);
        return n;
    }

    /**
     * Creates a normaliser that uses the same mean and standard deviation for
     * every element.
     *
     * @param length input (and output) length
     * @param mean   value subtracted from every element
     * @param stdev  divisor applied to every element (0.0 is treated as 1.0)
     * @return a new normaliser
     */
    public static Normaliser create(int length, double mean, double stdev) {
        Normaliser n = new Normaliser(length);
        n.mean.fill(mean);
        n.stdev.fill(stdev);
        return n;
    }

    /**
     * Returns the divisor for element {@code i}: the configured standard
     * deviation, or 1.0 when it is exactly zero. This avoids a division by zero
     * and is shared by the forward pass, the gradient and {@link #generate} so
     * the three stay mutually consistent.
     */
    private double scaleAt(int i) {
        double sd = this.stdev.get(i);
        return sd == 0.0 ? 1.0 : sd;
    }

    /** Forward pass: {@code output[i] = (input[i] - mean[i]) / scale[i]}. */
    @Override
    public void thinkInternal() {
        int n = this.getInputLength();
        for (int i = 0; i < n; ++i) {
            this.output.set(i, (this.input.get(i) - this.mean.get(i)) / this.scaleAt(i));
        }
    }

    /**
     * Returns an empty list because this component has no sub-components.
     * The original returned {@code null}; an empty list is safe to iterate.
     */
    @Override
    public List<IComponent> getComponents() {
        return Collections.emptyList();
    }

    /** Returns an empty vector: this component has no trainable parameters. */
    @Override
    public AVector getParameters() {
        return Vector0.INSTANCE;
    }

    /** Returns an empty vector: there are no parameters to hold a gradient for. */
    @Override
    public AVector getGradient() {
        return Vector0.INSTANCE;
    }

    /**
     * Propagates the output gradient back to the input. The derivative of
     * {@code (x - mean) / scale} with respect to {@code x} is {@code 1 / scale}.
     * The {@code factor} argument is unused because there are no parameters
     * to update.
     */
    @Override
    public void trainGradientInternal(double factor) {
        double[] og = this.outputGradient.getArray();
        double[] ig = this.inputGradient.getArray();
        for (int i = 0; i < og.length; ++i) {
            ig[i] = og[i] / this.scaleAt(i);
        }
    }

    /**
     * Returns an independent copy with the same mean and stdev values.
     * Internal input/output state is not copied.
     */
    @Override
    public Normaliser clone() {
        Normaliser n = new Normaliser(this.getInputLength());
        n.mean.set((ADenseArrayVector) this.mean);
        n.stdev.set((ADenseArrayVector) this.stdev);
        return n;
    }

    /**
     * Inverts the normalisation by writing
     * {@code output[i] * scale[i] + mean[i]} into {@code input}.
     * As a side effect, this component's own output and input state vectors
     * are also set to {@code output} and the reconstructed input.
     *
     * @param input  destination for the reconstructed input
     * @param output normalised values to invert
     */
    @Override
    public void generate(AVector input, AVector output) {
        this.output.set(output);
        int n = this.getInputLength();
        for (int i = 0; i < n; ++i) {
            this.input.set(i, output.get(i) * this.scaleAt(i) + this.mean.get(i));
        }
        input.set((ADenseArrayVector) this.input);
    }

    /**
     * Returns false: the forward computation above has no training-specific
     * branch.
     */
    @Override
    public boolean hasDifferentTrainingThinking() {
        return false;
    }
}

