/*
 * Decompiled with CFR 0.152.
 */
package mikera.matrixx.impl;

import java.util.Arrays;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.impl.ASingleBandMatrix;
import mikera.matrixx.impl.DiagonalMatrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Tools;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.AArrayVector;
import mikera.vectorz.impl.ZeroVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.VectorzException;

/**
 * Abstract base class for square diagonal matrices: matrices whose only
 * potentially non-zero entries lie on the leading diagonal.
 *
 * <p>Most operations here are expressed in terms of
 * {@link #unsafeGetDiagonalValue(int)}, which by default reads
 * {@code unsafeGet(i, i)}.
 */
public abstract class ADiagonalMatrix
extends ASingleBandMatrix {
    /** Number of rows, which is also the number of columns, of this square matrix. */
    protected final int dimensions;

    /**
     * Creates a diagonal matrix of the given size.
     *
     * @param dimensions the number of rows (and columns)
     */
    public ADiagonalMatrix(int dimensions) {
        this.dimensions = dimensions;
    }

    /** The only band that may hold non-zero values is the main diagonal (band 0). */
    @Override
    public int nonZeroBand() {
        return 0;
    }

    /** Diagonal matrices are always square. */
    @Override
    public boolean isSquare() {
        return true;
    }

    /** A diagonal matrix equals its own transpose. */
    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isDiagonal() {
        return true;
    }

    @Override
    public boolean isRectangularDiagonal() {
        return true;
    }

    /** All entries below the diagonal are zero. */
    @Override
    public boolean isUpperTriangular() {
        return true;
    }

    /** All entries above the diagonal are zero. */
    @Override
    public boolean isLowerTriangular() {
        return true;
    }

    @Override
    public abstract boolean isMutable();

    /** Off-diagonal entries can never be set, so the matrix is never fully mutable. */
    @Override
    public boolean isFullyMutable() {
        return false;
    }

    @Override
    public final int upperBandwidthLimit() {
        return 0;
    }

    @Override
    public final int lowerBandwidthLimit() {
        return 0;
    }

    /**
     * Returns the requested band of this matrix.
     *
     * @param band band index (0 = leading diagonal, positive = above, negative = below)
     * @return the leading diagonal for band 0, a zero vector of the band's length for
     *         any other in-range band, or {@code null} if the band lies outside the matrix
     */
    @Override
    public AVector getBand(int band) {
        if (band == 0) {
            return this.getLeadingDiagonal();
        }
        if (band >= this.dimensions || band <= -this.dimensions) {
            return null;
        }
        return ZeroVector.create(this.bandLength(band));
    }

    /** Returns the leading diagonal, the only band that can contain non-zeros. */
    @Override
    public AVector getNonZeroBand() {
        return this.getLeadingDiagonal();
    }

    /**
     * Computes the determinant as the product of the diagonal entries.
     * A 0x0 matrix yields 1.0 (the empty product).
     */
    @Override
    public double determinant() {
        double det = 1.0;
        for (int i = 0; i < this.dimensions; ++i) {
            det *= this.unsafeGetDiagonalValue(i);
        }
        return det;
    }

    /** @return the number of rows (and columns) of this matrix */
    public int dimensions() {
        return this.dimensions;
    }

    /**
     * Writes row {@code row} into {@code dest}: {@code dimensions} zeros starting at
     * {@code destOffset}, with the diagonal value placed at {@code destOffset + row}.
     *
     * @throws IndexOutOfBoundsException if {@code row} is not in {@code [0, dimensions)}
     */
    @Override
    public void copyRowTo(int row, double[] dest, int destOffset) {
        // Without this check an out-of-range row would write the diagonal value outside
        // the slice that was just zero-filled, silently corrupting the caller's buffer.
        if (row < 0 || row >= this.dimensions) {
            throw new IndexOutOfBoundsException(
                    "Index " + row + " out of range for dimension " + this.dimensions);
        }
        Arrays.fill(dest, destOffset, destOffset + this.dimensions, 0.0);
        dest[destOffset + row] = this.unsafeGetDiagonalValue(row);
    }

    /** Column {@code col} of a diagonal matrix has the same contents as row {@code col}. */
    @Override
    public void copyColumnTo(int col, double[] dest, int destOffset) {
        this.copyRowTo(col, dest, destOffset);
    }

    /**
     * Multiplies two diagonal matrices, producing a new diagonal matrix whose entries
     * are the element-wise products of the two diagonals.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    public ADiagonalMatrix innerProduct(ADiagonalMatrix a) {
        int dims = this.dimensions;
        if (dims != a.dimensions) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
        DiagonalMatrix result = DiagonalMatrix.createDimensions(dims);
        for (int i = 0; i < dims; ++i) {
            result.data[i] = this.unsafeGetDiagonalValue(i) * a.unsafeGetDiagonalValue(i);
        }
        return result;
    }

    /**
     * Computes {@code this * a}. Diagonal and dense {@link Matrix} arguments are routed
     * to their specialised overloads; any other matrix is handled by scaling each row
     * {@code i} of {@code a} by the {@code i}-th diagonal value.
     *
     * @throws IllegalArgumentException if {@code a.rowCount()} differs from this matrix's dimension
     */
    @Override
    public AMatrix innerProduct(AMatrix a) {
        if (a instanceof ADiagonalMatrix) {
            return this.innerProduct((ADiagonalMatrix)a);
        }
        if (a instanceof Matrix) {
            return this.innerProduct((Matrix)a);
        }
        return this.multiplyRowsByDiagonal(a);
    }

    /**
     * Computes {@code this * a} for a dense matrix by scaling each row of {@code a}.
     *
     * @throws IllegalArgumentException if {@code a.rowCount()} differs from this matrix's dimension
     */
    @Override
    public Matrix innerProduct(Matrix a) {
        return this.multiplyRowsByDiagonal(a);
    }

    /**
     * Shared implementation of {@code this * a}: returns a new dense matrix whose row
     * {@code i} is row {@code i} of {@code a} multiplied by the {@code i}-th diagonal value.
     */
    private Matrix multiplyRowsByDiagonal(AMatrix a) {
        if (this.dimensions != a.rowCount()) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
        int acc = a.columnCount();
        Matrix m = Matrix.create(this.dimensions, acc);
        for (int i = 0; i < this.dimensions; ++i) {
            double dv = this.unsafeGetDiagonalValue(i);
            for (int j = 0; j < acc; ++j) {
                m.unsafeSet(i, j, dv * a.unsafeGet(i, j));
            }
        }
        return m;
    }

    /** Because this matrix is its own transpose, this is the same as {@link #innerProduct(Matrix)}. */
    @Override
    public Matrix transposeInnerProduct(Matrix s) {
        return this.innerProduct(s);
    }

    /**
     * Multiplies {@code v} in place by this matrix, scaling element {@code i} by the
     * {@code i}-th diagonal value.
     *
     * @throws IllegalArgumentException if {@code v.length()} differs from this matrix's dimension
     */
    @Override
    public void transformInPlace(AVector v) {
        if (v instanceof AArrayVector) {
            this.transformInPlace((AArrayVector)v);
            return;
        }
        if (v.length() != this.dimensions) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, v));
        }
        for (int i = 0; i < this.dimensions; ++i) {
            v.unsafeSet(i, v.unsafeGet(i) * this.unsafeGetDiagonalValue(i));
        }
    }

    /**
     * Array-backed fast path of {@link #transformInPlace(AVector)}, operating directly on
     * the vector's backing array starting at its array offset.
     *
     * @throws IllegalArgumentException if {@code v.length()} differs from this matrix's dimension
     */
    @Override
    public void transformInPlace(AArrayVector v) {
        // Length must be checked here too: the AVector overload dispatches to this method
        // before its own check. Without it, a shorter vector fails with a raw array index
        // error or, when it is a view into a larger array, has elements past its end scaled.
        // A longer vector would be only partially transformed.
        if (v.length() != this.dimensions) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, v));
        }
        double[] data = v.getArray();
        int offset = v.getArrayOffset();
        for (int i = 0; i < this.dimensions; ++i) {
            data[offset + i] *= this.unsafeGetDiagonalValue(i);
        }
    }

    /**
     * Writes {@code this * source} into {@code dest}.
     *
     * @throws IllegalArgumentException if either vector's length differs from this matrix's dimension
     */
    @Override
    public void transform(Vector source, Vector dest) {
        int n = this.rowCount();
        if (source.length() != n) {
            throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
        }
        if (dest.length() != n) {
            throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
        }
        for (int row = 0; row < n; ++row) {
            dest.data[row] = source.data[row] * this.unsafeGetDiagonalValue(row);
        }
    }

    @Override
    public int rowCount() {
        return this.dimensions;
    }

    @Override
    public int columnCount() {
        return this.dimensions;
    }

    /** @return true if every diagonal entry is exactly 1.0 */
    @Override
    public boolean isIdentity() {
        for (int i = 0; i < this.dimensions; ++i) {
            if (this.unsafeGet(i, i) == 1.0) continue;
            return false;
        }
        return true;
    }

    /**
     * Returns true if every diagonal entry satisfies {@code Tools.isBoolean}.
     * Off-diagonal entries are zero and are therefore not inspected.
     */
    @Override
    public boolean isBoolean() {
        for (int i = 0; i < this.dimensions; ++i) {
            if (Tools.isBoolean(this.unsafeGet(i, i))) continue;
            return false;
        }
        return true;
    }

    /** No-op: a diagonal matrix is already equal to its transpose. */
    @Override
    public void transposeInPlace() {
    }

    /** Returns element {@code i} of {@code this * v}, which is {@code v[i] * diagonal[i]}. */
    @Override
    public double calculateElement(int i, AVector v) {
        return v.unsafeGet(i) * this.unsafeGetDiagonalValue(i);
    }

    /**
     * Always throws in this base class. Subclasses that allow diagonal updates
     * must override it.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void set(int row, int column, double value) {
        throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this, row, column));
    }

    /**
     * Returns the {@code i}-th diagonal entry with bounds checking.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, dimensions)}
     */
    public double getDiagonalValue(int i) {
        if (i < 0 || i >= this.dimensions) {
            throw new IndexOutOfBoundsException(
                    "Index " + i + " out of range for dimension " + this.dimensions);
        }
        return this.unsafeGet(i, i);
    }

    /** Returns the {@code i}-th diagonal entry without bounds checking. */
    public double unsafeGetDiagonalValue(int i) {
        return this.unsafeGet(i, i);
    }

    /** Returns {@code this}, since a diagonal matrix is its own transpose. */
    @Override
    public ADiagonalMatrix getTranspose() {
        return this;
    }

    /** Returns {@code this}, since a diagonal matrix is its own transpose. */
    @Override
    public ADiagonalMatrix getTransposeView() {
        return this;
    }

    /**
     * Returns {@code 1 / dimensions}, the fraction of entries on the diagonal
     * ({@code n} of {@code n*n}).
     * NOTE(review): for a 0x0 matrix this evaluates to positive infinity — confirm
     * what the library expects for empty matrices.
     */
    @Override
    public double density() {
        return 1.0 / (double)this.dimensions;
    }

    /** Returns a new dense matrix with the diagonal values copied in and zeros elsewhere. */
    @Override
    public Matrix toMatrix() {
        Matrix m = Matrix.create(this.dimensions, this.dimensions);
        for (int i = 0; i < this.dimensions; ++i) {
            // Flat index of element (i, i) in a square dense array is i * n + i.
            m.data[i * (this.dimensions + 1)] = this.unsafeGetDiagonalValue(i);
        }
        return m;
    }

    /** Same as {@link #toMatrix()}, since the matrix is symmetric. */
    @Override
    public final Matrix toMatrixTranspose() {
        return this.toMatrix();
    }

    /**
     * Checks that the leading diagonal's length matches {@code dimensions}, then
     * delegates to the superclass checks.
     *
     * @throws VectorzException on a dimension mismatch
     */
    @Override
    public void validate() {
        if (this.dimensions != this.getLeadingDiagonal().length()) {
            throw new VectorzException("dimension mismatch: " + this.dimensions);
        }
        super.validate();
    }

    @Override
    public abstract ADiagonalMatrix exactClone();
}

