/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import smile.math.matrix.Cholesky;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.math.matrix.SVD;
import smile.regression.OnlineRegression;
import smile.regression.RegressionTrainer;

/**
 * Recursive least squares (RLS) regression with an intercept term.
 *
 * <p>The model is fitted once on a batch of samples in the constructor and can
 * then be refined one sample at a time through {@link #learn(double[], double)}.
 * Each update rescales the matrix {@code V} by {@code 1 / lambda}, where
 * {@code lambda} is the forgetting factor in {@code (0, 1]}.
 *
 * <p>Instances keep mutable scratch buffers that are reused by every call to
 * {@code learn}; no synchronization is performed by this class.
 */
public class RLS
implements OnlineRegression<double[]> {
    private static final long serialVersionUID = 1L;
    /** Number of input features; the intercept is not counted. */
    private final int p;
    /** Coefficients: {@code w[0..p-1]} are the feature weights, {@code w[p]} is the intercept. */
    private final double[] w;
    /** Forgetting factor, always within {@code (0, 1]}. */
    private double lambda;
    /** (p+1) x (p+1) matrix updated in place by every {@code learn} call. */
    private final DenseMatrix V;
    /** Scratch copy of the current sample with a trailing constant 1 for the intercept. */
    private final double[] x1;
    /** Scratch buffer holding the product {@code V * x1}. */
    private final double[] Vx;

    /**
     * Fits the initial model with a forgetting factor of 1.
     *
     * @param x training samples, one row per sample.
     * @param y response values, one per row of {@code x}.
     * @throws IllegalArgumentException under the same conditions as
     *         {@link #RLS(double[][], double[], double)}.
     */
    public RLS(double[][] x, double[] y) {
        this(x, y, 1.0);
    }

    /**
     * Fits the initial model on a batch of samples.
     *
     * <p>The width of the first row defines the number of features. A column of
     * ones is appended to the samples so that the last coefficient acts as the
     * intercept.
     *
     * @param x training samples, one row per sample.
     * @param y response values, one per row of {@code x}.
     * @param lambda the forgetting factor, in {@code (0, 1]}.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length,
     *         if {@code lambda} is NaN or outside {@code (0, 1]}, if {@code x} is
     *         empty, or if there are not more rows than columns.
     */
    public RLS(double[][] x, double[] y, double lambda) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        checkForgettingFactor(lambda);
        this.lambda = lambda;
        int n = x.length;
        if (n == 0) {
            // Guard the x[0] access below so an empty training set fails like every other bad argument.
            throw new IllegalArgumentException("The input matrix is empty");
        }
        this.p = x[0].length;
        if (n <= p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }
        // Design matrix: the p features followed by a constant 1 (intercept column).
        DenseMatrix X = Matrix.zeros(n, p + 1);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < p; ++j) {
                X.set(i, j, x[i][j]);
            }
            X.set(i, p, 1.0);
        }
        this.w = new double[p + 1];
        SVD svd = X.svd();
        svd.solve(y, w);
        // V starts as the inverse obtained from the Cholesky factor of A'A
        // (presumably (X'X)^-1 -- confirm against the SVD class).
        Cholesky cholesky = svd.CholeskyOfAtA();
        this.V = cholesky.inverse();
        this.Vx = new double[p + 1];
        this.x1 = new double[p + 1];
        // The trailing 1 never changes; learn() only overwrites the first p entries.
        x1[p] = 1.0;
    }

    /**
     * Returns the coefficient array.
     *
     * @return the internal array (not a copy) of length {@code p + 1}; the last
     *         entry is the intercept.
     */
    public double[] coefficients() {
        return w;
    }

    /**
     * Predicts the response for one sample as intercept plus the dot product of
     * the sample with the feature weights.
     *
     * @param x the sample; must have exactly {@code p} entries.
     * @return the predicted response.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     */
    @Override
    public double predict(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }
        double y = w[p];
        for (int i = 0; i < x.length; ++i) {
            y += x[i] * w[i];
        }
        return y;
    }

    /**
     * Updates the model with a batch of samples, applied one at a time in order.
     *
     * @param x samples, one row per sample.
     * @param y response values, one per row of {@code x}.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     *         or any row has the wrong width.
     * @throws IllegalStateException if an update would make {@code V} singular.
     */
    public void learn(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("Input vector x of size %d not equal to length %d of y", x.length, y.length));
        }
        for (int i = 0; i < x.length; ++i) {
            learn(x[i], y[i]);
        }
    }

    /**
     * Updates the model with a single sample.
     *
     * <p>{@code V} is replaced by {@code (V - (Vx)(Vx)' / v) / lambda} with
     * {@code v = 1 + x'Vx}, then the coefficients are moved along the refreshed
     * {@code Vx} in proportion to the prediction error.
     *
     * @param x the sample; must have exactly {@code p} entries.
     * @param y the observed response.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     * @throws IllegalStateException if {@code v} is zero or NaN, i.e. the
     *         rank-one update cannot be applied.
     */
    @Override
    public void learn(double[] x, double y) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }
        System.arraycopy(x, 0, x1, 0, p);
        double v = 1.0 + V.xax(x1);
        // The update divides by v. A zero v is the singular case of the rank-one
        // update and a NaN v means V is already corrupted; 1/v is only NaN for a
        // NaN v, so zero has to be tested explicitly.
        if (v == 0.0 || Double.isNaN(v)) {
            throw new IllegalStateException("The updated V matrix is no longer invertible.");
        }
        V.ax(x1, Vx);
        // Vx is fixed during this loop, so each cell depends only on its own old value.
        for (int j = 0; j <= p; ++j) {
            for (int i = 0; i <= p; ++i) {
                double tmp = V.get(i, j) - Vx[i] * Vx[j] / v;
                V.set(i, j, tmp / lambda);
            }
        }
        // V has changed; recompute V * x1 before using it as the gain.
        V.ax(x1, Vx);
        // The error is measured with the coefficients from before this update.
        double err = y - this.predict(x);
        for (int i = 0; i <= p; ++i) {
            w[i] += Vx[i] * err;
        }
    }

    /**
     * Returns the forgetting factor.
     *
     * @return the current forgetting factor, in {@code (0, 1]}.
     */
    public double getForgettingFactor() {
        return lambda;
    }

    /**
     * Sets the forgetting factor used by subsequent updates.
     *
     * @param lambda the new forgetting factor, in {@code (0, 1]}.
     * @throws IllegalArgumentException if {@code lambda} is NaN or outside {@code (0, 1]}.
     */
    public void setForgettingFactor(double lambda) {
        checkForgettingFactor(lambda);
        this.lambda = lambda;
    }

    /** Rejects forgetting factors outside (0, 1]; the negated form also rejects NaN. */
    private static void checkForgettingFactor(double lambda) {
        if (!(lambda > 0.0 && lambda <= 1.0)) {
            throw new IllegalArgumentException("The forgetting factor must be in (0, 1]");
        }
    }

    /**
     * Trainer that builds an {@link RLS} model with the default forgetting factor.
     */
    public static class Trainer
    extends RegressionTrainer<double[]> {
        /**
         * Fits an {@link RLS} model on the given samples.
         *
         * @param x training samples, one row per sample.
         * @param y response values, one per row of {@code x}.
         * @return the fitted model.
         */
        public RLS train(double[][] x, double[] y) {
            return new RLS(x, y);
        }
    }
}

