/*
 * Decompiled with CFR 0.152.
 */
package opennlp.ccg.perceptron;

import java.io.IOException;
import opennlp.ccg.perceptron.Alphabet;
import opennlp.ccg.perceptron.EventFile;
import opennlp.ccg.perceptron.Model;

/**
 * Trains a linear model with the perceptron algorithm over blocks of events read from an
 * {@link EventFile}, keeping both the latest weights ({@link #currentModel}) and an averaged
 * model ({@link #averagedModel}).
 *
 * <p>For each block, the current model's best event is compared (by reference) with the block's
 * reference best event. On a mismatch, the reference event's features are added to the current
 * model and the predicted event's features are subtracted. After every block the current model
 * is accumulated into a per-iteration sum, which is then folded into the averaged model.
 *
 * <p>Without an initial model, and with the same number of blocks in every pass, the averaged
 * model after pass {@code i} is the mean of the current model taken after every block of passes
 * {@code 0..i}. If {@link #initModel(String)} was called, the averaged model additionally
 * contains the initial weights scaled by {@code 1 / (i + 1)}.
 */
public class Trainer {
    /** Usage text printed by {@link #main(String[])} on bad command lines. */
    private static final String USAGE = "Usage: java perceptron.Trainer <traineventfile> <alphabetfile> <iterations> <modelfile> (-i <initmodelfile>) (-f <finalmodelfile>) (-n <save-every-nth>) (-in_mem)";

    /** Path of the training event file. */
    public final String trainfile;
    /** Alphabet loaded from the alphabet file given to the constructor; shared by all models. */
    public final Alphabet alphabet;
    /** Maximum number of training passes over the event file. */
    public final int iterations;
    /** Path {@code main} writes the averaged model to; template for snapshot file names. */
    public final String modelfile;
    /** Forwarded to {@link EventFile}; {@code main} reports it as keeping events in memory. */
    public final boolean inMemory;
    /** Perceptron weights after the most recent update. */
    public final Model currentModel;
    /** Averaged model; see the class comment for exactly what it averages. */
    public final Model averagedModel;
    /** Scratch accumulator for one pass's contribution to the averaged model. */
    private final Model iterationModel;
    /** Snapshot the averaged model every this many passes; 0 or less disables snapshots. */
    private int saveEveryNth = 0;

    /**
     * Creates a trainer whose three models are each constructed as {@code new Model(alphabet)}.
     *
     * @param trainfile path of the training event file
     * @param alphabetfile path the {@link Alphabet} is loaded from
     * @param iterations maximum number of training passes
     * @param modelfile averaged-model path, also used as the template for snapshot file names
     * @param inMemory forwarded to the {@link EventFile} opened by {@link #train()}
     * @throws IOException if loading the alphabet fails
     */
    public Trainer(String trainfile, String alphabetfile, int iterations, String modelfile, boolean inMemory) throws IOException {
        this.trainfile = trainfile;
        this.alphabet = new Alphabet(alphabetfile);
        this.iterations = iterations;
        this.modelfile = modelfile;
        this.inMemory = inMemory;
        this.currentModel = new Model(this.alphabet);
        this.averagedModel = new Model(this.alphabet);
        this.iterationModel = new Model(this.alphabet);
    }

    /**
     * Sets how often {@link #train()} writes a snapshot of the averaged model.
     *
     * @param n snapshot after passes 0, n, 2n, ... other than the last pass; 0 or less disables
     */
    public void setSaveEveryNth(int n) {
        this.saveEveryNth = n;
    }

    /**
     * Seeds both the current and the averaged model from a saved model file.
     *
     * @param initmodelfile path of the model to start from
     * @throws IOException if loading the model fails
     */
    public void initModel(String initmodelfile) throws IOException {
        Model model = new Model(initmodelfile);
        this.currentModel.set(model);
        this.averagedModel.set(model);
    }

    /**
     * Runs up to {@link #iterations} perceptron passes over the training file.
     *
     * <p>Training stops early when a pass makes no updates. It also stops when the file contains
     * no blocks. If training did not converge, one extra scoring pass reports the accuracy of the
     * current and averaged models. The event file is always closed, even on failure.
     *
     * @throws IOException if reading the event file or writing a snapshot fails
     */
    public void train() throws IOException {
        EventFile eventFile = new EventFile(this.trainfile, this.alphabet, this.inMemory);
        try {
            boolean converged = false;
            for (int i = 0; i < this.iterations; ++i) {
                System.out.println("iteration: " + i);
                eventFile.reset();
                this.iterationModel.zero();
                int updates = 0;
                int correct = 0;
                int total = 0;
                EventFile.Block block;
                while ((block = eventFile.nextBlock()) != null) {
                    ++total;
                    EventFile.Event modelBest = this.currentModel.best(block);
                    EventFile.Event actualBest = block.best();
                    if (modelBest != actualBest) {
                        ++updates;
                        this.currentModel.add(actualBest.features);
                        this.currentModel.subtract(modelBest.features);
                    }
                    // The averaged model absorbs this pass only after the loop, so this counts
                    // accuracy of the average as it stood at the end of the previous pass.
                    if (this.averagedModel.best(block) == actualBest) {
                        ++correct;
                    }
                    this.iterationModel.add(this.currentModel);
                }
                if (total == 0) {
                    // Nothing to learn from; dividing by total below would fill the model with NaN.
                    System.out.println("no training events found in " + this.trainfile);
                    System.out.println();
                    return;
                }
                // Fold this pass into the running average over (i + 1) * total steps:
                // avg = avg * i / (i + 1) + passSum / (total * (i + 1)).
                this.iterationModel.multiply(1.0 / ((double) total * (i + 1)));
                if (i > 0) {
                    this.averagedModel.multiply((double) i / (i + 1));
                }
                this.averagedModel.add(this.iterationModel);
                System.out.println("updates: " + updates);
                System.out.println("avg model correct: " + correct + " total: " + total + " accuracy: " + accuracy(correct, total));
                System.out.println();
                if (updates == 0) {
                    System.out.println("converged");
                    System.out.println();
                    converged = true;
                    break;
                }
                // The last pass is skipped because main saves the final averaged model itself.
                if (this.saveEveryNth > 0 && i < this.iterations - 1 && i % this.saveEveryNth == 0) {
                    String nthModelfile = Trainer.nthFilename(this.modelfile, i);
                    System.out.println("Saving model to " + nthModelfile);
                    this.averagedModel.save(nthModelfile);
                    System.out.println();
                }
            }
            if (!converged) {
                reportFinalAccuracy(eventFile);
            }
        } finally {
            eventFile.close();
        }
    }

    /** Makes one more pass without updates and prints accuracy of the current and averaged models. */
    private void reportFinalAccuracy(EventFile eventFile) throws IOException {
        System.out.println("final iteration: ");
        eventFile.reset();
        int finalCorrect = 0;
        int correct = 0;
        int total = 0;
        EventFile.Block block;
        while ((block = eventFile.nextBlock()) != null) {
            ++total;
            EventFile.Event modelBest = this.currentModel.best(block);
            EventFile.Event avgModelBest = this.averagedModel.best(block);
            EventFile.Event actualBest = block.best();
            if (modelBest == actualBest) {
                ++finalCorrect;
            }
            if (avgModelBest == actualBest) {
                ++correct;
            }
        }
        System.out.println("final model correct: " + finalCorrect + " total: " + total + " accuracy: " + accuracy(finalCorrect, total));
        System.out.println("avg model correct: " + correct + " total: " + total + " accuracy: " + accuracy(correct, total));
        System.out.println();
    }

    /** Returns {@code correct / total}, or 0 when there were no examples (instead of NaN). */
    private static double accuracy(int correct, int total) {
        return total == 0 ? 0.0 : (double) correct / total;
    }

    /**
     * Inserts {@code "." + N} before the extension of the file name, or appends it if none.
     *
     * <p>Only a dot in the final path component counts as an extension separator. A leading dot
     * (a hidden file such as {@code ".model"}) does not count either. Examples:
     * {@code "model.gz" -> "model.3.gz"}, {@code "model" -> "model.3"},
     * {@code "run.v2/model" -> "run.v2/model.3"}.
     *
     * @param filename base file name or path
     * @param N number to insert
     * @return the numbered file name
     */
    public static String nthFilename(String filename, int N) {
        int lastSeparator = Math.max(filename.lastIndexOf('/'), filename.lastIndexOf(java.io.File.separatorChar));
        int lastdot = filename.lastIndexOf('.');
        if (lastdot > lastSeparator + 1) {
            return filename.substring(0, lastdot) + "." + N + filename.substring(lastdot);
        }
        return filename + "." + N;
    }

    /**
     * Command-line entry point.
     *
     * <p>Required arguments: {@code <traineventfile> <alphabetfile> <iterations> <modelfile>}.
     * Options:
     * <ul>
     *   <li>{@code -i <initmodelfile>} — initial model</li>
     *   <li>{@code -f <finalmodelfile>} — where to write the non-averaged model</li>
     *   <li>{@code -n <save-every-nth>} — snapshot frequency</li>
     *   <li>{@code -in_mem} — keep events in memory</li>
     * </ul>
     * The averaged model is written to {@code <modelfile>}.
     *
     * @throws IOException if reading inputs or writing models fails
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 4) {
            System.out.println(USAGE);
            System.exit(0);
        }
        String traineventfile = args[0];
        String alphabetfile = args[1];
        int iterations = Integer.parseInt(args[2]);
        String modelfile = args[3];
        String initmodelfile = null;
        String finalmodelfile = null;
        int saveEveryNth = 0;
        boolean inMemory = false;
        for (int i = 4; i < args.length; ++i) {
            String arg = args[i];
            if (arg.equals("-in_mem")) {
                inMemory = true;
            } else if (arg.equals("-i") || arg.equals("-f") || arg.equals("-n")) {
                if (i + 1 >= args.length) {
                    System.err.println("Missing value for option " + arg);
                    System.out.println(USAGE);
                    System.exit(1);
                }
                // Consume the value so it is never re-examined as an option itself.
                String value = args[++i];
                if (arg.equals("-i")) {
                    initmodelfile = value;
                } else if (arg.equals("-f")) {
                    finalmodelfile = value;
                } else {
                    saveEveryNth = Integer.parseInt(value);
                }
            } else {
                System.err.println("Ignoring unrecognized argument: " + arg);
            }
        }
        System.out.println("Training on " + traineventfile + " using " + alphabetfile + " for " + iterations + " iterations");
        if (initmodelfile != null) {
            System.out.println("with " + initmodelfile + " as the initial model");
        }
        if (inMemory) {
            System.out.println("keeping events in memory");
        }
        System.out.println();
        Trainer trainer = new Trainer(traineventfile, alphabetfile, iterations, modelfile, inMemory);
        if (initmodelfile != null) {
            trainer.initModel(initmodelfile);
        }
        trainer.setSaveEveryNth(saveEveryNth);
        trainer.train();
        System.out.println("Saving model to " + modelfile);
        trainer.averagedModel.save(modelfile);
        if (finalmodelfile != null) {
            System.out.println("Saving model to " + finalmodelfile);
            trainer.currentModel.save(finalmodelfile);
        }
    }
}

