/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import smile.regression.LASSO;
import smile.regression.Regression;

/**
 * Elastic net regression, fitted by delegating to {@link LASSO} on an
 * augmented data set.
 *
 * <p>With {@code c = 1 / sqrt(1 + lambda2)} and {@code p} input columns, the
 * constructor builds:
 * <ul>
 *   <li>an augmented design matrix whose first {@code n} rows are the
 *       training rows multiplied by {@code c} and whose last {@code p} rows
 *       form a diagonal matrix with entries {@code c * sqrt(lambda2)};</li>
 *   <li>an augmented response holding the original responses followed by
 *       {@code p} zeros.</li>
 * </ul>
 * A LASSO model is fitted on that data with penalty {@code lambda1 * c}, and
 * its coefficients and intercept are multiplied by {@code 1 / c} to obtain
 * this model's parameters.
 */
public class ElasticNet
implements Regression<double[]>,
Serializable {
    private static final long serialVersionUID = 1L;

    /** Value the four-argument constructor forwards as the fifth constructor argument. */
    private static final double DEFAULT_TOLERANCE = 1.0E-4;
    /** Value the four-argument constructor forwards as the sixth constructor argument. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;

    /** L1 penalty supplied by the caller. */
    private final double lambda1;
    /** L2 penalty supplied by the caller. */
    private final double lambda2;
    /** Number of input columns, taken from the first training row. */
    private final int p;
    /** Rescaled LASSO coefficients. */
    private final double[] w;
    /** Rescaled LASSO intercept. */
    private final double b;
    /** The LASSO model fitted on the augmented data. */
    private final LASSO lasso;
    /** Scaling factor {@code 1 / sqrt(1 + lambda2)}. */
    private final double c;

    /**
     * Fits an elastic net model, using {@code 1.0E-4} and {@code 1000} for
     * the two trailing arguments of the six-argument constructor.
     *
     * @param x       training rows
     * @param y       responses, one per training row
     * @param lambda1 L1 penalty; must be positive and finite
     * @param lambda2 L2 penalty; must be positive and finite
     * @throws IllegalArgumentException under the same conditions as the
     *         six-argument constructor
     */
    public ElasticNet(double[][] x, double[] y, double lambda1, double lambda2) {
        this(x, y, lambda1, lambda2, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Fits an elastic net model.
     *
     * @param x       training rows; must be non-empty and rectangular
     * @param y       responses, one per training row
     * @param lambda1 L1 penalty; must be positive and finite
     * @param lambda2 L2 penalty; must be positive and finite
     * @param tol     forwarded unchanged as the fourth LASSO constructor
     *                argument (presumably the convergence tolerance; confirm
     *                against LASSO)
     * @param maxIter forwarded unchanged as the fifth LASSO constructor
     *                argument (presumably the iteration limit; confirm
     *                against LASSO)
     * @throws IllegalArgumentException if a penalty is not positive and
     *         finite, if {@code x} is empty or ragged, or if {@code x} and
     *         {@code y} differ in length
     */
    public ElasticNet(double[][] x, double[] y, double lambda1, double lambda2, double tol, int maxIter) {
        if (lambda1 <= 0.0) {
            throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting:" + lambda1);
        }
        if (lambda2 <= 0.0) {
            throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting:" + lambda2);
        }
        // NaN compares false against 0.0 and +Infinity is greater than it, so
        // both slip past the checks above. An infinite lambda2 would make c
        // zero and the 1 / c rescaling infinite.
        if (Double.isNaN(lambda1) || Double.isInfinite(lambda1)) {
            throw new IllegalArgumentException("Invalid L1 portion setting:" + lambda1);
        }
        if (Double.isNaN(lambda2) || Double.isInfinite(lambda2)) {
            throw new IllegalArgumentException("Invalid L2 portion setting:" + lambda2);
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        this.lambda1 = lambda1;
        this.lambda2 = lambda2;
        this.c = 1.0 / Math.sqrt(1.0 + lambda2);
        this.p = x[0].length;
        // The helpers below read p, c and lambda2, so those must be set first.
        this.lasso = new LASSO(this.getAugmentedData(x), this.getAugmentedResponse(y), this.lambda1 * this.c, tol, maxIter);

        // NOTE(review): the intercept is rescaled by the same factor as the
        // coefficients; confirm this is intended given how LASSO derives its
        // intercept from the augmented rows.
        double[] lassoWeights = this.lasso.coefficients();
        double rescale = 1.0 / this.c;
        this.w = new double[lassoWeights.length];
        for (int j = 0; j < this.w.length; ++j) {
            this.w[j] = rescale * lassoWeights[j];
        }
        this.b = rescale * this.lasso.intercept();
    }

    /**
     * Predicts the response for one input vector as {@code dot(x, w) + b}.
     *
     * @param x input vector; its length must equal the number of training columns
     * @return the predicted response
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    public double predict(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        return smile.math.Math.dot(x, this.w) + this.b;
    }

    /**
     * Returns the coefficient vector.
     *
     * @return the internal coefficient array itself, not a copy; changes made
     *         to it affect later predictions
     */
    public double[] coefficients() {
        return this.w;
    }

    /**
     * Returns the intercept.
     *
     * @return the intercept added by {@link #predict(double[])}
     */
    public double intercept() {
        return this.b;
    }

    /**
     * Returns the underlying LASSO model.
     *
     * @return the LASSO model fitted on the augmented data
     */
    public LASSO lasso() {
        return this.lasso;
    }

    /**
     * Builds the augmented response: the original values followed by
     * {@code p} zeros.
     */
    private double[] getAugmentedResponse(double[] y) {
        // A new double[] is zero-filled, so only the leading part is copied.
        double[] augmented = new double[y.length + this.p];
        System.arraycopy(y, 0, augmented, 0, y.length);
        return augmented;
    }

    /**
     * Builds the augmented design matrix: the training rows scaled by
     * {@code c}, followed by a {@code p x p} diagonal block with entries
     * {@code c * sqrt(lambda2)}.
     *
     * @throws IllegalArgumentException if a row's length differs from {@code p}
     */
    private double[][] getAugmentedData(double[][] x) {
        int n = x.length;
        double[][] augmented = new double[n + this.p][this.p];
        for (int i = 0; i < n; ++i) {
            double[] row = x[i];
            // Longer rows used to be silently truncated and shorter rows
            // failed with an index error, so reject ragged input explicitly.
            if (row.length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid row size at index %d: %d, expected: %d", i, row.length, this.p));
            }
            for (int j = 0; j < this.p; ++j) {
                augmented[i][j] = this.c * row[j];
            }
        }
        double ridgeEntry = this.c * Math.sqrt(this.lambda2);
        for (int j = 0; j < this.p; ++j) {
            augmented[n + j][j] = ridgeEntry;
        }
        return augmented;
    }
}

