/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module.layers;

import java.util.ArrayList;
import java.util.List;
import mikera.indexz.Index;
import mikera.indexz.Indexz;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrixx;
import mikera.vectorz.AVector;
import mikera.vectorz.GrowableVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.AArrayVector;
import mikera.vectorz.impl.Vector0;
import nuroko.module.AWeightLayer;

public final class SparseWeightLayer
extends AWeightLayer {
    private final Vector bias;
    private final Vector biasGradient;
    private final Index[] indexes;
    private final Vector[] weights;
    private final Vector[] weightGradients;
    private AVector parameters;
    private AVector gradient;

    public SparseWeightLayer(int inputLength, int outputLength, int maxLinks) {
        super(inputLength, outputLength);
        this.bias = Vector.createLength((int)outputLength);
        this.weights = new Vector[outputLength];
        this.indexes = new Index[outputLength];
        int links = Math.min(maxLinks, inputLength);
        Vector0 params = Vector0.INSTANCE;
        for (int j = 0; j < outputLength; ++j) {
            Vector wts = Vector.createLength((int)links);
            Index inds = Indexz.createRandomChoice((int)links, (int)inputLength);
            this.weights[j] = wts;
            this.indexes[j] = inds;
            params = params.join((AVector)wts);
        }
        params = params.join((AVector)this.bias);
        this.parameters = params;
        this.biasGradient = Vector.createLength((int)outputLength);
        this.weightGradients = new Vector[outputLength];
        Vector0 g = Vector0.INSTANCE;
        for (int j = 0; j < outputLength; ++j) {
            Vector grd;
            this.weightGradients[j] = grd = Vector.createLength((int)this.weights[j].length());
            g = g.join((AVector)grd);
        }
        g = g.join((AVector)this.biasGradient);
        this.gradient = g;
        assert (this.gradient.length() == this.parameters.length());
    }

    public SparseWeightLayer(SparseWeightLayer wl) {
        super(wl.inputLength, wl.outputLength);
        this.bias = wl.bias.clone();
        this.biasGradient = wl.biasGradient.clone();
        this.weights = new Vector[this.outputLength];
        this.indexes = new Index[this.outputLength];
        Vector params = this.bias;
        for (int j = 0; j < this.outputLength; ++j) {
            Vector wts = wl.weights[j].clone();
            Index inds = wl.indexes[j].clone();
            this.weights[j] = wts;
            this.indexes[j] = inds;
            params = params.join((AVector)wts);
        }
        this.parameters = params;
        this.weightGradients = new Vector[this.biasGradient.length()];
        Vector g = this.biasGradient;
        for (int j = 0; j < this.outputLength; ++j) {
            Vector grd;
            this.weightGradients[j] = grd = wl.weightGradients[j].clone();
            g = g.join((AVector)grd);
        }
        this.gradient = g;
        assert (this.gradient.length() == this.parameters.length());
    }

    @Override
    public AMatrix asMatrix() {
        return Matrixx.createSparse((int)this.getInputLength(), (Index[])this.indexes, (AVector[])this.weights);
    }

    @Override
    public AVector getParameters() {
        return this.parameters;
    }

    @Override
    public int getParameterLength() {
        return this.parameters.length();
    }

    @Override
    public void think(AVector input, AVector output) {
        assert (this.inputLength == input.length());
        assert (this.outputLength == output.length());
        this.setInput(input);
        this.thinkInternal();
        output.set((AVector)this.getOutput());
    }

    @Override
    public void thinkInternal() {
        for (int i = 0; i < this.outputLength; ++i) {
            double val = this.bias.get(i);
            Vector wts = this.weights[i];
            Index inds = this.indexes[i];
            this.output.set(i, val += this.input.dotProduct(wts, inds));
        }
    }

    @Override
    public void generate(AVector input, AVector output) {
        input.fill(0.0);
        for (int i = 0; i < this.outputLength; ++i) {
            input.addMultiple(this.weights[i], this.indexes[i], output.unsafeGet(i));
        }
    }

    @Override
    public AVector getGradient() {
        return this.gradient;
    }

    @Override
    public void trainGradientInternal(double factor) {
        this.inputGradient.fill(0.0);
        this.biasGradient.addMultiple((AArrayVector)this.outputGradient, factor *= this.getLearnFactor());
        for (int j = 0; j < this.outputLength; ++j) {
            double grad = this.outputGradient.get(j);
            this.weightGradients[j].addMultiple(this.indexes[j], this.input, grad * factor);
            this.inputGradient.addMultiple(this.weights[j], this.indexes[j], grad);
        }
    }

    @Override
    public SparseWeightLayer clone() {
        SparseWeightLayer wl = new SparseWeightLayer(this);
        wl.getParameters().set(this.getParameters());
        return wl;
    }

    @Override
    public int getLinkCount(int outputIndex) {
        return this.weights[outputIndex].length();
    }

    @Override
    public double getLinkWeight(int outputIndex, int number) {
        return this.weights[outputIndex].data[number];
    }

    @Override
    public int getLinkSource(int outputIndex, int number) {
        return this.indexes[outputIndex].data[number];
    }

    @Override
    public void initRandom() {
        Vectorz.fillGaussian((AVector)this.bias, (double)0.0, (double)0.3);
        for (Vector v : this.weights) {
            Vectorz.fillGaussian((AVector)v, (double)0.0, (double)(1.0 / Math.sqrt(v.length())));
        }
    }

    @Override
    public Index getSourceIndex(int outputIndex) {
        return this.indexes[outputIndex];
    }

    @Override
    public AVector getSourceWeights(int outputIndex) {
        return this.weights[outputIndex];
    }

    @Override
    public SparseWeightLayer getInverse() {
        int i;
        int inps = this.getInputLength();
        int outps = this.getOutputLength();
        SparseWeightLayer wl = new SparseWeightLayer(this.getOutputLength(), this.getInputLength(), 0);
        GrowableVector[] weightVectors = new GrowableVector[inps];
        ArrayList[] indexVectors = new ArrayList[inps];
        for (i = 0; i < inps; ++i) {
            weightVectors[i] = new GrowableVector();
            indexVectors[i] = new ArrayList();
        }
        for (int j = 0; j < outps; ++j) {
            AVector owts = this.getSourceWeights(j);
            Index oixs = this.getSourceIndex(j);
            assert (owts.length() == oixs.length());
            for (int i2 = 0; i2 < owts.length(); ++i2) {
                int si = oixs.get(i2);
                indexVectors[si].add(j);
                weightVectors[si].append(owts.get(i2));
            }
        }
        for (i = 0; i < inps; ++i) {
            wl.weights[i] = new Vector((AVector)weightVectors[i]);
            wl.indexes[i] = Indexz.create((List)indexVectors[i]);
            wl.weightGradients[i] = Vector.createLength((int)wl.weights[i].length());
        }
        wl.rebuildVectors();
        return wl;
    }

    private void rebuildVectors() {
        Vector0 params = Vector0.INSTANCE;
        Vector0 grads = Vector0.INSTANCE;
        for (int i = 0; i < this.outputLength; ++i) {
            params = params.join((AVector)this.weights[i]);
            grads = grads.join((AVector)this.weightGradients[i]);
        }
        params = params.join((AVector)this.bias);
        grads = grads.join((AVector)this.biasGradient);
        this.parameters = params;
        this.gradient = grads;
    }

    @Override
    public boolean hasDifferentTrainingThinking() {
        return false;
    }

    public Vector getBias() {
        return this.bias;
    }
}

