package mikera.arrayz;

import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import us.bpsm.edn.parser.Parseable;
import us.bpsm.edn.parser.Parser;
import us.bpsm.edn.parser.Parsers;
import mikera.arrayz.impl.SliceArray;
import mikera.matrixx.Matrix;
import mikera.matrixx.Matrixx;
import mikera.matrixx.impl.StridedMatrix;
import mikera.vectorz.AScalar;
import mikera.vectorz.AVector;
import mikera.vectorz.Scalar;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ArrayIndexScalar;
import mikera.vectorz.impl.ArraySubVector;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;

/**
 * Static function class for array operations
 * 
 * @author Mike
 */
public class Arrayz {
	/**
	 * Creates an array from the given data.
	 * 
	 * Handles INDArray instances (converted via {@link #create(INDArray)}), double[] arrays,
	 * Numbers (converted to scalars), Lists, Object[] arrays and other primitive arrays
	 * (int[], long[], float[] etc.). Lists and arrays are converted according to the type of
	 * their first element; elements that are themselves lists/arrays are converted recursively
	 * and used as slices.
	 * 
	 * @param object The data to convert. Must not be null.
	 * @return An INDArray containing the given data
	 * @throws VectorzException if the data cannot be converted to an array
	 */
	@SuppressWarnings("unchecked")
	public static INDArray create(Object object) {
		if (object instanceof INDArray) return create((INDArray)object);
		
		if (object instanceof double[]) return Vector.of((double[])object);
		if (object instanceof List<?>) {
			List<?> list=(List<?>) object;
			if (list.size()==0) return Vector0.INSTANCE;
			Object o1=list.get(0);
			if ((o1 instanceof AScalar)||(o1 instanceof Number)) {
				return Vectorz.create((List<Object>)object);
			} else if (o1 instanceof AVector) {
				return Matrixx.create((List<Object>)object);
			} else if (o1 instanceof INDArray) {
				return SliceArray.create((List<INDArray>)object);
			} else {
				ArrayList<INDArray> al=new ArrayList<INDArray>(list.size());
				for (Object o: list) {
					al.add(create(o));
				}
				// Re-dispatch rather than building a SliceArray directly: the converted slices
				// may be AVectors (e.g. a list of lists), in which case a Matrix is produced.
				return Arrayz.create(al);
			}
		}
		
		if (object instanceof Number) return Scalar.create(((Number)object).doubleValue());
		
		if (object instanceof Object[]) {
			return create(Arrays.asList((Object[])object));
		}
		
		if (object.getClass().isArray()) {
			// Primitive arrays other than double[] cannot be cast to Object[], so copy their
			// (boxed) elements into a List and convert that instead.
			int n=java.lang.reflect.Array.getLength(object);
			ArrayList<Object> elements=new ArrayList<Object>(n);
			for (int i=0; i<n; i++) {
				elements.add(java.lang.reflect.Array.get(object, i));
			}
			return create(elements);
		}
		
		throw new VectorzException("Don't know how to create array from: "+object.getClass());
	}
	
	/**
	 * Create a new array instance with the given shape. New array will be filled with zeroes.
	 * 
	 * A Scalar is returned for a 0-dimensional shape, a Vector for 1 dimension, a Matrix for
	 * 2 dimensions and an Array otherwise.
	 * 
	 * @param shape The shape of the new array
	 * @return A new zero-filled array with the given shape
	 */
	public static INDArray newArray(int... shape) {
		int dims=shape.length;
		
		switch (dims) {
			case 0: return Scalar.create(0.0);
			case 1: return Vector.createLength(shape[0]);
			case 2: return Matrix.create(shape[0], shape[1]);
			default: return Array.newArray(shape);
		}
	}
	
	/**
	 * Creates a new array with the same shape and element values as the given array.
	 * 
	 * A Scalar, Vector, Matrix or Array is returned depending on the dimensionality of a,
	 * populated from the values returned by {@code a.toDoubleArray()}.
	 * NOTE(review): independence from a relies on toDoubleArray() returning a fresh array.
	 * 
	 * @param a The source array
	 * @return A new array containing the values of a
	 */
	public static INDArray create(INDArray a) {
		int dims=a.dimensionality();
		switch (dims) {
		case 0:
			return Scalar.create(a.get());
		case 1:
			return Vector.create(a.toDoubleArray());
		case 2:
			return Matrix.wrap(a.getShape(0), a.getShape(1), a.toDoubleArray());
		default:
			return Array.wrap(a.toDoubleArray(),a.getShape());
		}
	}
	
	/**
	 * Creates an array using the given data as slices.
	 * 
	 * @param data The slices of the new array
	 * @return An INDArray containing the given data
	 */
	public static INDArray create(Object... data) {
		return create((Object)data);
	}
	
	/**
	 * Creates an INDArray instance wrapping the given double data, with the provided shape.
	 * 
	 * Data is interpreted in row-major order starting at index 0. If data contains more
	 * elements than the shape requires, only the leading elements are used.
	 * 
	 * @param data The data to wrap
	 * @param shape The shape of the resulting array
	 * @return An INDArray wrapping the given data
	 * @throws IllegalArgumentException if data has fewer elements than the shape requires
	 */
	public static INDArray wrap(double[] data, int[] shape) {
		int dlength=data.length;
		switch (shape.length) {
			case 0:
				// a scalar needs exactly one element, at index 0
				if (dlength<1) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
				return ArrayIndexScalar.wrap(data,0);
				
			case 1:
				int n=shape[0];
				if (dlength<n) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
				if (n==dlength) {
					return Vector.wrap(data); 
				} else {
					return ArraySubVector.wrap(data, 0, n);
				}
				
			case 2:
				int rc=shape[0], cc=shape[1];
				// computed as long so that large shapes cannot overflow past the size check
				long ec=((long)rc)*cc;
				if (dlength<ec) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
				if (ec==dlength) {
					return Matrix.wrap(rc,cc, data);
				} else {
					// row-major view of the leading rc*cc elements: row stride cc, column stride 1
					return StridedMatrix.wrap(data, rc, cc, 0, cc, 1);
				}
		
			default:
				long eec=IntArrays.arrayProduct(shape);
				if (dlength<eec) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
				if (eec==dlength) {
					return Array.wrap(data, shape);
				} else {
					return NDArray.wrap(data, shape);
				}
		}
	}

	/**
	 * Creates a new array using the elements in the specified vector.
	 * Truncates or zero-pads the data as required to fill the new array
	 * 
	 * @param a The vector supplying the elements
	 * @param shape The shape of the new array
	 * @return A new array with the given shape
	 */
	public static INDArray createFromVector(AVector a, int... shape) {
		int dims=shape.length;
		if (dims==0) {
			return Scalar.createFromVector(a);
		} else if (dims==1) {
			return Vector.createFromVector(a,shape[0]);
		} else if (dims==2) {
			return Matrixx.createFromVector(a, shape[0], shape[1]);
		} else {
			return Array.createFromVector(a,shape);
		}
	}
	
	/**
	 * Reads a single edn value from the given reader and converts it to an array
	 * using {@link #create(Object)}.
	 * 
	 * The reader is not closed by this method.
	 * 
	 * @param reader Source of edn text
	 * @return An INDArray containing the parsed data
	 * @throws VectorzException if the input contains no value or the value cannot be converted
	 */
	public static INDArray load(Reader reader) {
		Parseable pbr=Parsers.newParseable(reader);
		Parser p = Parsers.newParser(Parsers.defaultConfiguration());
		Object value=p.nextValue(pbr);
		// an empty input yields the END_OF_INPUT sentinel rather than a data value
		if (value==Parser.END_OF_INPUT) {
			throw new VectorzException("No edn value found in input");
		}
		return Arrayz.create(value);
	}
	
	/**
	 * Parse an array from a String. String should be in edn format
	 * 
	 * @param ednString
	 * @return An INDArray containing the parsed data
	 */
	public static INDArray parse(String ednString) {
		return load(new StringReader(ednString));	
	}

	/**
	 * Creates an INDArray wrapping the given data with an arbitrary offset and strides.
	 * 
	 * For more than 2 dimensions, a fully packed row-major layout over the whole of data
	 * (see {@link #isPackedLayout}) is wrapped as an Array; otherwise a strided NDArray is used.
	 * 
	 * @param data The data to wrap
	 * @param offset Index in data of the first element
	 * @param shape The shape of the resulting array
	 * @param strides The stride for each dimension; expected to have the same length as shape
	 * @return An INDArray wrapping the given data
	 */
	public static INDArray wrapStrided(double[] data, int offset, int[] shape, int[] strides) {
		int dims=shape.length;
		if (dims==0) {
			return ArrayIndexScalar.wrap(data, offset);
		} else if (dims==1) {
			return Vectorz.wrapStrided(data, offset, shape[0], strides[0]);
		} else if (dims==2) {
			return Matrixx.wrapStrided(data, shape[0],shape[1], offset, strides[0],strides[1]);
		} else {
			if (isPackedLayout(data,offset,shape,strides)) {
				return Array.wrap(data, shape);
			} else {
				return NDArray.wrapStrided(data,offset,shape,strides);
			}
		}
	}
	
	/**
	 * Checks if the given offset and strides describe a fully packed, row major layout for
	 * the given shape that uses exactly all of the given data array.
	 * 
	 * @param data The backing data array
	 * @param offset Index in data of the first element
	 * @param shape The shape of the array
	 * @param strides The stride for each dimension
	 * @return true if offset is zero, strides are packed row-major, and the element count equals data.length
	 */
	public static boolean isPackedLayout(double[] data, int offset, int[] shape, int[] strides) {
		if (offset!=0) return false;
		int dims=shape.length;
		// long accumulator so that large shapes cannot overflow into a false match
		long st=1;
		for (int i=dims-1; i>=0; i--) {
			if (strides[i]!=st) return false;
			st*=shape[i];
		}
		return (st==data.length);
	}

	/**
	 * Checks if the given set of strides represents a fully packed, row major layout for the given shape
	 * 
	 * @param shape The shape of the array
	 * @param strides The stride for each dimension
	 * @return true if each stride equals the product of all later dimensions
	 */
	public static boolean isPackedStrides(int[] shape, int[] strides) {
		int dims=shape.length;
		// long accumulator so that large shapes cannot overflow into a false match
		long st=1;
		for (int i=dims-1; i>=0; i--) {
			if (strides[i]!=st) return false;
			st*=shape[i];
		}
		return true;
	}
}
