/*
 * Decompiled with CFR 0.152.
 */
package nuroko.testing;

import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import nuroko.module.AComponent;
import nuroko.module.loss.SquaredErrorLoss;
import org.junit.Assert;

/**
 * Finite-difference gradient checker for {@link AComponent} implementations.
 *
 * <p>A clone of the component is trained once on random Gaussian input/target vectors to obtain
 * its analytic parameter gradient. Each parameter is then perturbed in turn and the change in
 * {@code -||output - target||^2} is compared against the corresponding gradient entry.
 */
public class DerivativeTest {
    /** Step added to a single parameter for the forward finite-difference estimate. */
    private static final double EPS = 1.0E-6;

    /** Analytic gradient entries whose magnitude is not above this are treated as zero and skipped. */
    private static final double MIN_CHECKED_GRADIENT = 1.0E-7;

    /** Smallest accepted ratio of numeric estimate to analytic gradient. */
    private static final double MIN_RATIO = 0.8;

    /** Largest accepted ratio of numeric estimate to analytic gradient. */
    private static final double MAX_RATIO = 1.2;

    /**
     * Checks the analytic parameter gradient of a component against a forward finite difference.
     *
     * <p>The check runs on {@code c.clone()}, so the component passed in is not trained or
     * perturbed by this method. It expects the gradient accumulated by
     * {@code train(x, t, SquaredErrorLoss.INSTANCE, 1.0)} to match the derivative of
     * {@code -||think(x) - t||^2} with respect to each parameter, within a ratio of
     * [{@value #MIN_RATIO}, {@value #MAX_RATIO}].
     *
     * <p>NOTE(review): entries whose analytic gradient is (near) zero, or whose numeric estimate is
     * exactly zero, are skipped. A backprop that wrongly reports zero for a parameter is therefore
     * not detected by this check.
     *
     * @param c the component under test
     * @throws AssertionError if an analytic gradient entry is NaN or disagrees with the numeric
     *     estimate beyond the accepted ratio
     */
    public static void testDerivative(AComponent c) {
        AComponent comp = c.clone();
        comp.getGradient().fill(0.0);

        // Declared as AVector so the same AVector overloads are selected as in the original casts.
        AVector target = Vector.createLength((int) comp.getOutputLength());
        Vectorz.fillGaussian(target);
        AVector input = Vector.createLength((int) comp.getInputLength());
        Vectorz.fillGaussian(input);

        // One training pass fills the analytic gradient; its output is the unperturbed baseline.
        comp.train(input, target, SquaredErrorLoss.INSTANCE, 1.0);
        AVector analytic = comp.getGradient().clone();
        AVector baseParams = comp.getParameters().clone();
        double baseScore = -comp.getOutput().distanceSquared(target);

        int n = comp.getParameterLength();
        for (int i = 0; i < n; ++i) {
            // Restore every parameter, then nudge only parameter i.
            comp.getParameters().set(baseParams);
            comp.getParameters().addAt(i, EPS);
            double perturbedScore = -comp.think(input).distanceSquared(target);
            double expected = (perturbedScore - baseScore) / EPS;
            double calculated = analytic.get(i);

            // A NaN gradient is always a bug. It must not slip through the magnitude filter below.
            Assert.assertFalse("Gradient at position " + i + " is NaN", Double.isNaN(calculated));

            // Skip entries too small to compare meaningfully as a ratio (this also covers 0.0).
            if (!(Math.abs(calculated) > MIN_CHECKED_GRADIENT) || expected == 0.0) {
                continue;
            }
            double ratio = expected / calculated;
            boolean ok = ratio >= MIN_RATIO && ratio <= MAX_RATIO;
            Assert.assertTrue(
                    "Gradient at position " + i + " expected=" + expected + " calculated=" + calculated,
                    ok);
        }
    }
}

