/*
 * Decompiled with CFR 0.152.
 */
package com.github.pmerienne.trident.ml.regression;

import com.github.pmerienne.trident.ml.regression.Regressor;
import com.github.pmerienne.trident.ml.util.MathUtil;

/**
 * Online linear regressor trained one example at a time.
 *
 * <p>A prediction is {@code MathUtil.dot(weights, features)} (presumably the inner product of
 * the weight and feature vectors). Each call to {@link #update(Double, double[])} moves every
 * weight by {@code features[i] * (expected - prediction) * learningRate}.
 *
 * <p>The weight vector has exactly one entry per feature and there is no separate bias term. If
 * an intercept is needed, callers must append a constant feature themselves.
 *
 * <p>Weights are created lazily, zero-filled by default (see {@link #initWeights(int)}), sized
 * from the first feature vector seen. This class holds mutable state and performs no
 * synchronization.
 */
public class PerceptronRegressor implements Regressor {

    /** Model weights; {@code null} until the first prediction or after {@link #reset()}. */
    private double[] weights;

    /**
     * Step size applied to each weight correction.
     * Kept public for backward compatibility with callers that access it directly.
     */
    public double learningRate = 0.1;

    /** Creates a regressor with the default learning rate of {@code 0.1}. */
    public PerceptronRegressor() {
    }

    /**
     * Creates a regressor with the given learning rate.
     *
     * @param learningRate step size applied to each weight correction
     */
    public PerceptronRegressor(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Predicts the target value for a feature vector.
     *
     * <p>If no weights exist yet, they are first initialized with {@code features.length}
     * entries.
     *
     * @param features the input feature vector
     * @return the result of {@code MathUtil.dot(weights, features)}
     */
    @Override
    public Double predict(double[] features) {
        if (this.weights == null) {
            this.initWeights(features.length);
        }
        return MathUtil.dot(this.weights, features);
    }

    /**
     * Performs one online training step on a single labelled example.
     *
     * <p>The current prediction is computed first, which may lazily initialize the weights. If
     * it is exactly equal to {@code expected} (same semantics as {@link Double#equals(Object)}),
     * the weights are left untouched.
     *
     * @param expected the target value; must not be {@code null}
     * @param features the input feature vector. NOTE(review): it is assumed to have no more
     *                 entries than the weight vector; a longer vector would index past the end
     *                 of {@code weights}.
     * @throws NullPointerException if {@code expected} is {@code null} or the prediction is
     *                              {@code null}
     */
    @Override
    public void update(Double expected, double[] features) {
        double prediction = this.predict(features);
        // Double.compare(...) == 0 matches Double.equals exactly (NaN == NaN, -0.0 != 0.0).
        if (Double.compare(expected, prediction) == 0) {
            return;
        }
        // Primitive arithmetic avoids boxing a Double on every iteration of the loop below.
        double error = expected - prediction;
        for (int i = 0; i < features.length; ++i) {
            this.weights[i] += features[i] * error * this.learningRate;
        }
    }

    /**
     * Allocates the weight vector. The default implementation zero-fills it; subclasses may
     * override this to use a different initialization.
     *
     * @param size number of weights, equal to the number of features
     */
    protected void initWeights(int size) {
        this.weights = new double[size];
    }

    /** Discards the learned weights; they are re-created on the next prediction. */
    @Override
    public void reset() {
        this.weights = null;
    }

    /**
     * Returns the live weight array (not a copy).
     *
     * @return the weights, or {@code null} before the first prediction or after {@link #reset()}
     */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Replaces the weight array. The given array is stored directly, not copied.
     *
     * @param weights the new weights; may be {@code null} to force lazy re-initialization
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    /** @return the step size applied to each weight correction */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @param learningRate the new step size applied to each weight correction */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public String toString() {
        return "PerceptronRegressor [learningRate=" + this.learningRate + "]";
    }
}

