/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.types.DataType;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtPredictor;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A {@code BaseModel} implementation backed by a PyTorch (TorchScript) module loaded through
 * {@link JniUtils}.
 */
public class PtModel
extends BaseModel {

    /**
     * Creates an empty model whose arrays are owned by a sub-manager of the PyTorch system
     * manager on the given device.
     *
     * @param device the device to use; {@code null} is resolved via {@code Device.defaultIfNull}
     */
    PtModel(Device device) {
        device = Device.defaultIfNull(device);
        this.manager = PtNDManager.getSystemManager().newSubManager(device);
        this.dataType = DataType.FLOAT32;
    }

    /**
     * Loads the model from a directory.
     *
     * <p>If no block has been set, the module {@code <modelName>.pt} inside {@code modelPath} is
     * loaded via JNI and becomes the model's block. Otherwise the already-set block keeps its
     * structure and only its parameters are read using {@code options}.
     *
     * @param modelPath the directory containing the model files
     * @param modelName the model name; used as the {@code .pt} file's base name
     * @param options loading options forwarded to {@code readParameters} when a block is set
     * @throws FileNotFoundException if no block is set and {@code <modelName>.pt} does not exist
     * @throws IOException if reading model files fails
     * @throws MalformedModelException if the parameters cannot be read into the block
     */
    public void load(Path modelPath, String modelName, Map<String, String> options) throws IOException, MalformedModelException {
        this.modelDir = modelPath.toAbsolutePath();
        this.modelName = modelName;
        if (this.block == null) {
            Path modelFile = this.modelDir.resolve(modelName + ".pt");
            if (Files.notExists(modelFile)) {
                throw new FileNotFoundException(".pt file not found in: " + modelPath);
            }
            this.block = JniUtils.loadModule((PtNDManager) this.manager, modelFile, this.manager.getDevice());
        } else {
            this.readParameters(options);
        }
    }

    /**
     * Training is not supported by this engine binding.
     *
     * @param trainingConfig ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Creates a predictor for this model that uses the given translator.
     *
     * @param translator converts between user types and model inputs/outputs
     * @param <I> the predictor input type
     * @param <O> the predictor output type
     * @return a new {@code PtPredictor}
     */
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        return new PtPredictor<I, O>(this, translator, false);
    }

    /**
     * Lists the regular files under the model directory, recursively, excluding {@code .pt}
     * files.
     *
     * @return the file paths relative to the model directory
     * @throws AssertionError if the model directory cannot be walked
     */
    public String[] getArtifactNames() {
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> paths = Files.walk(this.modelDir)) {
            return paths
                    .filter(Files::isRegularFile)
                    // getFileName() works on any FileSystem, unlike toFile().
                    .filter(p -> !p.getFileName().toString().endsWith(".pt"))
                    .map(p -> this.modelDir.relativize(p).toString())
                    .toArray(String[]::new);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    /**
     * Casting the model's data type is not supported by this engine binding.
     *
     * @param dataType ignored
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Releases the resources owned by this model's manager. */
    public void close() {
        this.manager.close();
    }
}

