/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.pytorch.jni.PyTorchLibrary;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/**
 * Static helpers that convert between DJL arrays and native PyTorch IValue handles through {@link
 * PyTorchLibrary#LIB}, and that run a {@link PtSymbolBlock} forward pass.
 */
public final class IValueUtils {

    /** Input names of the form {@code word[]} are packed into a single list argument. */
    private static final Pattern LIST_NAME = Pattern.compile("\\w+\\[]");

    private IValueUtils() {
    }

    /**
     * Wraps a tensor handle into a new IValue handle via {@code iValueFromTensor}.
     *
     * @param arrayHandle native tensor handle
     * @return native IValue handle
     */
    public static long toIValuePointer(long arrayHandle) {
        return PyTorchLibrary.LIB.iValueFromTensor(arrayHandle);
    }

    /**
     * Builds a list IValue from tensor handles via {@code iValueFromList}.
     *
     * @param pointers native tensor handles, in list order
     * @return native IValue handle
     */
    public static long iValueFromList(long[] pointers) {
        return PyTorchLibrary.LIB.iValueFromList(pointers);
    }

    /**
     * Builds a dict IValue via {@code iValueFromDict}; {@code pointers[i]} is stored under {@code
     * names[i]}.
     *
     * @param pointers native tensor handles
     * @param names dict keys, parallel to {@code pointers}
     * @return native IValue handle
     */
    public static long iValueFromDict(long[] pointers, String[] names) {
        return PyTorchLibrary.LIB.iValueFromDict(pointers, names);
    }

    /** Delegates to {@code iValueIsTensor}. */
    public static boolean isNDArray(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueIsTensor(iValueHandle);
    }

    /** Delegates to {@code iValueIsTensorList}. */
    public static boolean isNDList(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueIsTensorList(iValueHandle);
    }

    /** Delegates to {@code iValueIsList}. */
    public static boolean isList(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueIsList(iValueHandle);
    }

    /** Delegates to {@code iValueIsTuple}. */
    public static boolean isTuple(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueIsTuple(iValueHandle);
    }

    /** Delegates to {@code iValueIsMap}. */
    public static boolean isMap(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueIsMap(iValueHandle);
    }

    /** Delegates to {@code iValueIsString}. */
    public static boolean isString(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueIsString(iValueHandle);
    }

    /**
     * Extracts the tensor held by an IValue and wraps it in a {@link PtNDArray} attached to {@code
     * manager}. The IValue handle itself is not released here.
     *
     * @param iValueHandle native IValue handle
     * @param manager manager the new array is created under
     * @return the wrapping array
     */
    public static PtNDArray toNDArray(long iValueHandle, PtNDManager manager) {
        long ndHandle = PyTorchLibrary.LIB.iValueToTensor(iValueHandle);
        return new PtNDArray(manager, ndHandle);
    }

    /**
     * Extracts every tensor of a tensor-list IValue, in order. The IValue handle itself is not
     * released here.
     *
     * @param iValueHandle native IValue handle
     * @param manager manager the new arrays are created under
     * @return one {@link PtNDArray} per tensor
     */
    public static NDList toNDList(long iValueHandle, PtNDManager manager) {
        long[] ndHandles = PyTorchLibrary.LIB.iValueToTensorList(iValueHandle);
        NDList list = new NDList();
        for (long handle : ndHandles) {
            list.add(new PtNDArray(manager, handle));
        }
        return list;
    }

    /** Delegates to {@code iValueToString}. */
    public static String toString(long iValueHandle) {
        return PyTorchLibrary.LIB.iValueToString(iValueHandle);
    }

    /**
     * Returns the element IValue handles of a tuple or list IValue.
     *
     * @param iValueHandle native tuple or list IValue handle
     * @return element IValue handles, in order
     */
    public static long[] toIValueArray(long iValueHandle) {
        if (isTuple(iValueHandle)) {
            return PyTorchLibrary.LIB.iValueToListFromTuple(iValueHandle);
        }
        return PyTorchLibrary.LIB.iValueToList(iValueHandle);
    }

    /**
     * Returns the key/value IValue handles of a map IValue.
     *
     * <p>The native call returns a flat {@code [key0, value0, key1, value1, ...]} array. The
     * returned map iterates in that same order. That makes dict outputs deterministic instead of
     * depending on the numeric value of native handles, as a hash map would.
     *
     * @param iValueHandle native map IValue handle
     * @return key handle to value handle, in native order
     */
    public static Map<Long, Long> toIValueMap(long iValueHandle) {
        long[] iValueHandles = PyTorchLibrary.LIB.iValueToMap(iValueHandle);
        Map<Long, Long> map = new LinkedHashMap<>();
        for (int i = 0; i < iValueHandles.length; i += 2) {
            map.put(iValueHandles[i], iValueHandles[i + 1]);
        }
        return map;
    }

    /**
     * Converts an IValue (recursively) into a flat {@link NDList} and releases the handle.
     *
     * <p>This call owns {@code iValueHandle}: it is deleted whether conversion succeeds or throws.
     *
     * @throws UnsupportedOperationException if the IValue is not a tensor, tensor list, list,
     *     tuple or map
     */
    private static NDList forwardHelper(long iValueHandle, PtNDManager manager) {
        try {
            return convertIValue(iValueHandle, manager);
        } finally {
            PyTorchLibrary.LIB.torchDeleteIValue(iValueHandle);
        }
    }

    /** Converts an IValue into an NDList; frees any child handles it obtains, but not the input. */
    private static NDList convertIValue(long iValueHandle, PtNDManager manager) {
        NDList list = new NDList();
        if (isNDArray(iValueHandle)) {
            list.add(toNDArray(iValueHandle, manager));
        } else if (isNDList(iValueHandle)) {
            list.addAll(toNDList(iValueHandle, manager));
        } else if (isList(iValueHandle) || isTuple(iValueHandle)) {
            long[] children = toIValueArray(iValueHandle);
            int consumed = 0;
            try {
                while (consumed < children.length) {
                    // Advance before the call: forwardHelper takes ownership of the child and
                    // frees it even if conversion throws.
                    long child = children[consumed++];
                    list.addAll(forwardHelper(child, manager));
                }
            } finally {
                // Only non-empty on failure: children never handed to forwardHelper.
                for (int i = consumed; i < children.length; ++i) {
                    PyTorchLibrary.LIB.torchDeleteIValue(children[i]);
                }
            }
        } else if (isMap(iValueHandle)) {
            Map<Long, Long> entries = toIValueMap(iValueHandle);
            try {
                for (Map.Entry<Long, Long> entry : entries.entrySet()) {
                    String name = IValueUtils.toString(entry.getKey());
                    PtNDArray value = toNDArray(entry.getValue(), manager);
                    value.setName(name);
                    list.add(value);
                }
            } finally {
                // Release every key/value IValue, including those after a failing entry.
                for (Map.Entry<Long, Long> entry : entries.entrySet()) {
                    PyTorchLibrary.LIB.torchDeleteIValue(entry.getKey());
                    PyTorchLibrary.LIB.torchDeleteIValue(entry.getValue());
                }
            }
        } else {
            throw new UnsupportedOperationException("Unsupported IValue type");
        }
        return list;
    }

    /**
     * Runs the module's forward pass on {@code inputs} and flattens the result into an {@link
     * NDList}.
     *
     * <p>Input names control argument packing: {@code x[]} groups arrays into one list argument,
     * and {@code x.key} groups arrays into one dict argument under {@code key}. Any other name
     * (or none) is passed as a single tensor argument.
     *
     * <p>NOTE(review): the input IValues built here are not freed in Java; presumably {@code
     * moduleForward} releases them natively — confirm before adding a Java-side delete.
     *
     * @param block the TorchScript block to run
     * @param inputs input arrays; must be non-empty, since the output manager is taken from the
     *     first one
     * @param isTrain forwarded to {@code moduleForward}
     * @return flattened outputs
     */
    public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain) {
        // Resolve the manager before any native work. An empty input list then fails before the
        // model runs, instead of leaking the native forward result.
        PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
        long[] arrayHandles =
                inputs.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray();
        String[] names = inputs.stream().map(NDArray::getName).toArray(String[]::new);
        long[] iValueInputs = getInputs(arrayHandles, names);
        long result = PyTorchLibrary.LIB.moduleForward(block.getHandle(), iValueInputs, isTrain);
        return forwardHelper(result, manager);
    }

    private static boolean isNameList(String name) {
        return LIST_NAME.matcher(name).matches();
    }

    private static boolean isNameDict(String name) {
        return name.contains(".");
    }

    /** How a positional module argument is materialized as an IValue. */
    private enum InputKind {
        TENSOR,
        LIST,
        DICT
    }

    /** One positional module argument, collected from one or more named input arrays. */
    private static final class InputGroup {
        private final InputKind kind;
        private final List<String> keys = new ArrayList<>();
        private final List<Long> handles = new ArrayList<>();

        InputGroup(InputKind kind) {
            this.kind = kind;
        }

        void add(String key, long handle) {
            keys.add(key);
            handles.add(handle);
        }

        /** Creates the native IValue for this argument. */
        long toIValue() {
            long[] tensors = toPrimitiveLongArray(handles.toArray(new Long[0]));
            switch (kind) {
                case TENSOR:
                    return toIValuePointer(tensors[0]);
                case LIST:
                    return iValueFromList(tensors);
                case DICT:
                    return iValueFromDict(tensors, keys.toArray(new String[0]));
                default:
                    throw new AssertionError(kind);
            }
        }
    }

    /**
     * Packs tensor handles into positional IValue arguments according to their names.
     *
     * <p>Argument order follows the first appearance of each tensor/list/dict group. All names are
     * validated before any native IValue is created, so a rejected name leaks nothing.
     *
     * @throws IllegalArgumentException if a dict-style name contains more than one {@code '.'}
     */
    private static long[] getInputs(long[] arrays, String[] names) {
        List<InputGroup> groups = new ArrayList<>();
        // Separate namespaces: a dict named "x[]" (from "x[].k") must not merge with list "x[]".
        Map<String, InputGroup> listGroups = new HashMap<>();
        Map<String, InputGroup> dictGroups = new HashMap<>();
        for (int i = 0; i < arrays.length; ++i) {
            String name = names[i];
            if (name != null && isNameDict(name)) {
                String[] parts = name.split("\\.");
                Preconditions.checkArgument(
                        parts.length == 2,
                        "Please make sure you only include one '.' in the name. Nested Map is not supported!");
                findOrCreateGroup(dictGroups, parts[0], InputKind.DICT, groups)
                        .add(parts[1], arrays[i]);
            } else if (name != null && isNameList(name)) {
                findOrCreateGroup(listGroups, name, InputKind.LIST, groups).add(name, arrays[i]);
            } else {
                InputGroup tensor = new InputGroup(InputKind.TENSOR);
                tensor.add(null, arrays[i]);
                groups.add(tensor);
            }
        }
        long[] pointers = new long[groups.size()];
        for (int i = 0; i < pointers.length; ++i) {
            pointers[i] = groups.get(i).toIValue();
        }
        return pointers;
    }

    /** Returns the group registered under {@code name}, creating and appending it if absent. */
    private static InputGroup findOrCreateGroup(
            Map<String, InputGroup> index, String name, InputKind kind, List<InputGroup> groups) {
        InputGroup group = index.get(name);
        if (group == null) {
            group = new InputGroup(kind);
            index.put(name, group);
            groups.add(group);
        }
        return group;
    }

    /** Unboxes {@code array}; returns {@code null} for a {@code null} input. */
    private static long[] toPrimitiveLongArray(Long[] array) {
        if (array == null) {
            return null;
        }
        return Stream.of(array).mapToLong(Long::longValue).toArray();
    }
}

