/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import mikera.util.Rand;
import mikera.vectorz.impl.ADenseArrayVector;
import nuroko.module.AOperationComponent;

/**
 * Inverted-dropout component.
 *
 * <p>During a training pass ({@link #thinkInternalTraining()}) each element of the input is
 * independently zeroed with probability {@code dropoutRate}. Surviving elements are scaled by
 * {@code 1 / (1 - dropoutRate)}. During inference ({@link #thinkInternal()}) the input is copied to
 * the output unchanged. The backward pass ({@link #trainGradientInternal(double)}) applies the same
 * mask and the same scale to the gradient.
 */
public class Dropout
extends AOperationComponent {
    /** Dropout rate used by {@link #Dropout(int)}. */
    private static final double DEFAULT_DROPOUT_RATE = 0.5;

    /** Probability, in [0, 1), that any single element is dropped during a training pass. */
    private final double dropoutRate;

    /**
     * Per-element mask written by the most recent training pass; {@code true} means the element
     * was zeroed. Read back by {@link #trainGradientInternal(double)}.
     */
    private final boolean[] dropped;

    /**
     * Creates a dropout component with the default rate of 0.5.
     *
     * @param length number of elements in the input/output vectors
     */
    public Dropout(int length) {
        this(length, DEFAULT_DROPOUT_RATE);
    }

    /**
     * Creates a dropout component with the given rate.
     *
     * @param length number of elements in the input/output vectors
     * @param dropoutRate probability of dropping each element during training; must be in [0, 1)
     * @throws IllegalArgumentException if {@code dropoutRate} is NaN or outside [0, 1)
     */
    public Dropout(int length, double dropoutRate) {
        super(length);
        this.dropoutRate = checkRate(dropoutRate);
        this.dropped = new boolean[length];
    }

    /**
     * Validates a dropout rate.
     *
     * @param rate candidate dropout rate
     * @return {@code rate}, unchanged
     * @throws IllegalArgumentException if {@code rate} is NaN or outside [0, 1)
     */
    private static double checkRate(double rate) {
        // Written as a negated range test so that NaN is rejected too.
        // A rate of 1 would make the keep-scale 1/(1-rate) infinite.
        if (!(rate >= 0.0 && rate < 1.0)) {
            throw new IllegalArgumentException("dropoutRate must be in [0, 1): " + rate);
        }
        return rate;
    }

    /**
     * Returns the factor applied to surviving elements, {@code 1 / (1 - dropoutRate)}. The value
     * is exactly 1.0 when the rate is 0. The forward and backward passes share this helper so
     * that they cannot disagree.
     */
    private double keepScale() {
        return 1.0 / (1.0 - this.dropoutRate);
    }

    /** Inference forward pass: copies the input to the output unchanged. */
    @Override
    public void thinkInternal() {
        this.output.set((ADenseArrayVector)this.input);
    }

    /**
     * Training forward pass: copies the input to the output. It then zeroes each element with
     * probability {@code dropoutRate} and scales the survivors by {@code 1 / (1 - dropoutRate)}.
     * The drop mask is recorded for the subsequent backward pass.
     */
    @Override
    public void thinkInternalTraining() {
        this.output.set((ADenseArrayVector)this.input);
        if (this.dropoutRate > 0.0) {
            double scaleFactor = keepScale();
            // NOTE(review): assumes getArray() exposes the vector's elements starting at index 0
            // (no offset) -- confirm for the vector type used by AOperationComponent.
            double[] dt = this.output.getArray();
            // Iterate over the component length (the size of `dropped`), not the backing array
            // length, so the mask can never be indexed out of bounds.
            for (int i = 0; i < this.length; ++i) {
                boolean drop = Rand.chance(this.dropoutRate);
                dt[i] = drop ? 0.0 : dt[i] * scaleFactor;
                this.dropped[i] = drop;
            }
        }
    }

    /**
     * Backward pass: copies the output gradient into the input gradient. It then zeroes entries
     * for dropped elements and scales the rest by {@code 1 / (1 - dropoutRate)}, which is the
     * derivative of the training forward pass.
     *
     * <p>Assumes {@link #thinkInternalTraining()} was run for the same sample beforehand, so that
     * the drop mask is current.
     *
     * @param factor not used by this component
     */
    @Override
    public void trainGradientInternal(double factor) {
        // Kept outputs were multiplied by 1/(1-p) in the forward pass, so their gradients must be
        // multiplied by the same factor (not by 1-p).
        double scaleFactor = keepScale();
        this.inputGradient.set((ADenseArrayVector)this.outputGradient);
        double[] ig = this.inputGradient.getArray();
        for (int i = 0; i < this.length; ++i) {
            ig[i] = this.dropped[i] ? 0.0 : ig[i] * scaleFactor;
        }
    }

    /** Returns {@code true}: the training forward pass differs from inference. */
    @Override
    public boolean hasDifferentTrainingThinking() {
        return true;
    }

    /** Returns a new {@code Dropout} with the same input length and dropout rate. */
    @Override
    public Dropout clone() {
        return new Dropout(this.getInputLength(), this.dropoutRate);
    }
}

