/*
 * Decompiled with CFR 0.152.
 */
package mikera.arrayz.impl;

import mikera.arrayz.INDArray;
import mikera.arrayz.impl.AbstractArray;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;

/**
 * A view of two arrays joined (concatenated) along a single dimension.
 *
 * <p>Along the join dimension, indexes below {@link #split} address the {@link #left} array and
 * the remaining indexes address the {@link #right} array, shifted down by {@code split}. The
 * component arrays must agree in length along every other dimension. Element writes through
 * {@link #set(int[], double)} are forwarded to the underlying component arrays.
 */
public class JoinedArray
extends AbstractArray<INDArray> {
    /** Combined shape: the shape of {@code left}, lengthened along {@link #dimension}. */
    final int[] shape;
    /** Component that serves indexes {@code [0, split)} along the join dimension. */
    final INDArray left;
    /** Component that serves indexes {@code [split, shape[dimension])} along the join dimension. */
    final INDArray right;
    /** The dimension along which {@code left} and {@code right} are joined. */
    final int dimension;
    /** Length of {@code left} along the join dimension, i.e. the first index served by {@code right}. */
    final int split;

    /**
     * Creates the joined view. Shape compatibility must already have been checked by the caller
     * (see {@link #join(INDArray, INDArray, int)}).
     */
    private JoinedArray(INDArray left, INDArray right, int dim) {
        this.left = left;
        this.right = right;
        this.dimension = dim;
        this.shape = left.getShapeClone();
        this.split = this.shape[dim];
        this.shape[dim] += right.getShape(dim);
    }

    /**
     * Joins two arrays along the given dimension, returning a view over both.
     *
     * @param a the array occupying the lower indexes along {@code dim}
     * @param b the array occupying the higher indexes along {@code dim}
     * @param dim the dimension to join along
     * @return a view whose length along {@code dim} is the sum of the inputs' lengths
     * @throws IllegalArgumentException if the arrays differ in dimensionality, differ in length
     *     along any dimension other than {@code dim}, or {@code dim} is not a valid dimension
     */
    public static JoinedArray join(INDArray a, INDArray b, int dim) {
        int n = a.dimensionality();
        if (b.dimensionality() != n) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a, b));
        }
        // Reject out-of-range dimensions up front; otherwise the constructor would fail with an
        // ArrayIndexOutOfBoundsException when indexing the shape array.
        if (dim < 0 || dim >= n) {
            throw new IllegalArgumentException(
                    "Invalid join dimension " + dim + " for arrays of dimensionality " + n);
        }
        for (int i = 0; i < n; ++i) {
            if (i != dim && a.getShape(i) != b.getShape(i)) {
                throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a, b));
            }
        }
        return new JoinedArray(a, b, dim);
    }

    @Override
    public int dimensionality() {
        return this.shape.length;
    }

    /**
     * Returns the shape of this array.
     *
     * <p>The internal shape array is returned directly (no copy), so callers must not modify it.
     */
    @Override
    public int[] getShape() {
        return this.shape;
    }

    /**
     * Returns the element at the given full index, reading from whichever component holds it.
     *
     * @throws IllegalArgumentException if the number of indexes differs from the dimensionality
     */
    @Override
    public double get(int ... indexes) {
        if (indexes.length != this.dimensionality()) {
            throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, indexes));
        }
        if (indexes[this.dimension] < this.split) {
            return this.left.get(indexes);
        }
        // Copy before shifting so the caller's index array is left untouched.
        int[] rightIndexes = indexes.clone();
        rightIndexes[this.dimension] -= this.split;
        return this.right.get(rightIndexes);
    }

    /**
     * Sets the element at the given full index, writing into whichever component holds it.
     *
     * @throws IllegalArgumentException if the number of indexes differs from the dimensionality
     */
    @Override
    public void set(int[] indexes, double value) {
        if (indexes.length != this.dimensionality()) {
            throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, indexes));
        }
        if (indexes[this.dimension] < this.split) {
            this.left.set(indexes, value);
            return;
        }
        // Copy before shifting so the caller's index array is left untouched.
        int[] rightIndexes = indexes.clone();
        rightIndexes[this.dimension] -= this.split;
        this.right.set(rightIndexes, value);
    }

    /**
     * Returns the slice at {@code majorSlice} along dimension 0.
     */
    @Override
    public INDArray slice(int majorSlice) {
        if (this.dimension == 0) {
            // Dimension 0 is the join dimension: the slice lies wholly within one component.
            return majorSlice < this.split
                    ? this.left.slice(majorSlice)
                    : this.right.slice(majorSlice - this.split);
        }
        // Slicing removes dimension 0, so the join dimension shifts down by one.
        return new JoinedArray(this.left.slice(majorSlice), this.right.slice(majorSlice), this.dimension - 1);
    }

    /**
     * Returns the slice at {@code index} along the given {@code dimension}.
     *
     * <p>If {@code dimension} is the join dimension, the slice lies wholly within one component.
     * Otherwise both components are sliced along {@code dimension} and joined again. The join
     * dimension moves down by one if the removed dimension was before it.
     */
    @Override
    public INDArray slice(int dimension, int index) {
        if (dimension == this.dimension) {
            return index < this.split
                    ? this.left.slice(dimension, index)
                    : this.right.slice(dimension, index - this.split);
        }
        if (dimension == 0) {
            return this.slice(index);
        }
        // Removing a dimension below the join dimension shifts the join dimension down by one;
        // removing one above it leaves the join dimension unchanged.
        int joinDim = dimension < this.dimension ? this.dimension - 1 : this.dimension;
        return new JoinedArray(this.left.slice(dimension, index), this.right.slice(dimension, index), joinDim);
    }

    @Override
    public int sliceCount() {
        return this.shape[0];
    }

    @Override
    public long elementCount() {
        return IntArrays.arrayProduct(this.shape);
    }

    /** Always {@code true}: element reads and writes are forwarded to the component arrays. */
    @Override
    public boolean isView() {
        return true;
    }

    /** Returns a new joined array over exact clones of both components, joined along the same dimension. */
    @Override
    public INDArray exactClone() {
        return new JoinedArray(this.left.exactClone(), this.right.exactClone(), this.dimension);
    }

    /**
     * Checks that the combined length along the join dimension matches the component lengths,
     * then runs the inherited validation.
     */
    @Override
    public void validate() {
        if (this.left.getShape(this.dimension) + this.right.getShape(this.dimension) != this.shape[this.dimension]) {
            throw new Error("Inconsistent shape along split dimension");
        }
        super.validate();
    }
}

