/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import smile.classification.ClassifierTrainer;
import smile.classification.SoftClassifier;
import smile.math.Math;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;

/**
 * Linear discriminant analysis classifier over dense {@code double[]} feature vectors.
 *
 * <p>Training estimates one mean vector per class, a single covariance matrix shared by all
 * classes, and the eigen decomposition of that matrix. Prediction scores every class with
 * {@code log(prior) - 0.5 * sum_j(ux[j]^2 / eigen[j])}, where {@code ux} is the centered input
 * projected through the eigenvector matrix, and returns the class with the largest score.
 *
 * <p>Class labels must be the consecutive integers {@code 0, 1, ..., k-1}.
 */
public class LDA
implements SoftClassifier<double[]> {
    private static final long serialVersionUID = 1L;
    /** Tolerance used by the constructors that do not take an explicit {@code tol}. */
    private static final double DEFAULT_TOLERANCE = 1.0E-4;
    /** Number of features, taken from the length of the first training row. */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Per-class constant term of the discriminant function: {@code log(priori[i])}. */
    private final double[] ct;
    /** A priori probability of each class. */
    private final double[] priori;
    /** Per-class mean vectors, indexed {@code [class][feature]}. */
    private final double[][] mu;
    /** Eigenvectors of the shared covariance matrix. */
    private final DenseMatrix scaling;
    /** Eigenvalues of the shared covariance matrix. */
    private final double[] eigen;

    /**
     * Trains with priors estimated from the class frequencies and the default tolerance.
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code [0, k)}
     */
    public LDA(double[][] x, int[] y) {
        this(x, y, null);
    }

    /**
     * Trains with the given priors and the default tolerance.
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code [0, k)}
     * @param priori a priori class probabilities, or {@code null} to estimate them from {@code y}
     */
    public LDA(double[][] x, int[] y, double[] priori) {
        this(x, y, priori, DEFAULT_TOLERANCE);
    }

    /**
     * Trains with priors estimated from the class frequencies.
     *
     * @param x training samples, one row per sample
     * @param y class labels in {@code [0, k)}
     * @param tol singularity tolerance; must be non-negative
     */
    public LDA(double[][] x, int[] y, double tol) {
        this(x, y, null, tol);
    }

    /**
     * Trains the classifier.
     *
     * @param x training samples, one row per sample
     * @param y class labels; the distinct values must be exactly {@code 0 .. k-1} with {@code k >= 2}
     * @param priori a priori class probabilities, each strictly between 0 and 1 and summing to 1,
     *        or {@code null} to use the class frequencies observed in {@code y}; the array is copied
     * @param tol tolerance for the singularity checks; a variance or an eigenvalue below
     *        {@code tol * tol} is rejected
     * @throws IllegalArgumentException if the arguments are inconsistent or invalid, if there are
     *         not more samples than classes, or if the covariance matrix is close to singular
     */
    public LDA(double[][] x, int[] y, double[] priori, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (priori != null) {
            checkPriori(priori);
        }
        this.k = countClasses(y);
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (priori != null && this.k != priori.length) {
            throw new IllegalArgumentException("The number of classes and the number of priori probabilities don't match.");
        }
        // Negated form so that NaN is rejected too; a NaN tolerance would make every
        // "value < tol" singularity test below evaluate to false.
        if (!(tol >= 0.0)) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        int n = x.length;
        if (n <= this.k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, this.k));
        }
        this.p = x[0].length;

        int[] ni = new int[this.k];
        double[] mean = Math.colMeans(x);
        DenseMatrix C = Matrix.zeros(this.p, this.p);

        // Per-class feature sums, turned into means once the class counts are known.
        this.mu = new double[this.k][this.p];
        for (int i = 0; i < n; ++i) {
            int c = y[i];
            ++ni[c];
            double[] classSum = this.mu[c];
            for (int j = 0; j < this.p; ++j) {
                classSum[j] += x[i][j];
            }
        }
        for (int i = 0; i < this.k; ++i) {
            double[] classMean = this.mu[i];
            for (int j = 0; j < this.p; ++j) {
                classMean[j] /= ni[i];
            }
        }

        double[] prior;
        if (priori == null) {
            prior = new double[this.k];
            for (int i = 0; i < this.k; ++i) {
                prior[i] = (double) ni[i] / n;
            }
        } else {
            // Copy so that later changes to the caller's array cannot make the stored
            // priors disagree with the log-priors computed below.
            prior = priori.clone();
        }
        this.priori = prior;
        this.ct = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            this.ct[i] = Math.log(prior[i]);
        }

        // Accumulate the lower triangle of the scatter matrix.
        // NOTE(review): deviations are taken from the overall column mean rather than from the
        // per-class means in mu, while the divisor below is n - k. Confirm this is the intended
        // estimator before changing it; it determines every prediction.
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < this.p; ++j) {
                for (int l = 0; l <= j; ++l) {
                    C.add(j, l, (x[i][j] - mean[j]) * (x[i][l] - mean[l]));
                }
            }
        }

        // Variances and eigenvalues are squared quantities, so compare against tol squared.
        double threshold = tol * tol;
        for (int j = 0; j < this.p; ++j) {
            for (int l = 0; l <= j; ++l) {
                C.div(j, l, n - this.k);
                // Mirror the lower triangle into the upper one.
                C.set(l, j, C.get(j, l));
            }
            if (C.get(j, j) < threshold) {
                throw new IllegalArgumentException(String.format("Covariance matrix (variable %d) is close to singular.", j));
            }
        }
        C.setSymmetric(true);

        EVD evd = C.eigen();
        double[] eigenValues = evd.getEigenValues();
        for (double s : eigenValues) {
            if (s < threshold) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
        }
        this.eigen = eigenValues;
        this.scaling = evd.getEigenVectors();
    }

    /**
     * Validates user-supplied a priori probabilities.
     *
     * @throws IllegalArgumentException if there are fewer than two values, if any value is not
     *         strictly between 0 and 1 (NaN included), or if the values do not sum to one
     */
    private static void checkPriori(double[] priori) {
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
        }
        double sum = 0.0;
        for (double pr : priori) {
            // Positive form of the range test: NaN fails it, whereas
            // "pr <= 0 || pr >= 1" would let NaN through.
            if (!(pr > 0.0 && pr < 1.0)) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }
        if (!(Math.abs(sum - 1.0) <= 1.0E-10)) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /**
     * Checks that the distinct labels are exactly {@code 0, 1, ..., k-1} and returns {@code k}.
     *
     * @throws IllegalArgumentException if a label is negative or a class in the range is absent,
     *         including class 0
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        // Starting from -1 makes the gap test also catch a label set that does not begin at 0,
        // which would otherwise index past the end of the per-class arrays during training.
        int previous = -1;
        for (int label : labels) {
            if (label < 0) {
                throw new IllegalArgumentException("Negative class label: " + label);
            }
            if (label > previous + 1) {
                throw new IllegalArgumentException("Missing class: " + (previous + 1));
            }
            previous = label;
        }
        return labels.length;
    }

    /**
     * Returns the a priori probability of each class.
     *
     * @return a copy of the priors, indexed by class label
     */
    public double[] getPriori() {
        return this.priori.clone();
    }

    /**
     * Predicts the class label of an instance.
     *
     * @param x the feature vector; its length must equal the training dimension
     * @return the predicted class label
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, (double[])null);
    }

    /**
     * Predicts the class label of an instance and optionally reports posterior probabilities.
     *
     * @param x the feature vector; its length must equal the training dimension
     * @param posteriori if not {@code null}, an array of length {@code k} that is overwritten
     *        with the normalized class probabilities
     * @return the predicted class label; the lowest label wins when scores tie
     * @throws IllegalArgumentException if {@code x} or {@code posteriori} has the wrong length
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        int y = 0;
        double max = Double.NEGATIVE_INFINITY;
        double[] d = new double[this.p];
        double[] ux = new double[this.p];
        for (int i = 0; i < this.k; ++i) {
            for (int j = 0; j < this.p; ++j) {
                d[j] = x[j] - this.mu[i][j];
            }
            // Project the centered input through the eigenvector matrix
            // (atx is presumably the transpose-times-vector product; confirm against DenseMatrix).
            this.scaling.atx(d, ux);
            double quad = 0.0;
            for (int j = 0; j < this.p; ++j) {
                quad += ux[j] * ux[j] / this.eigen[j];
            }
            double f = this.ct[i] - 0.5 * quad;
            if (max < f) {
                max = f;
                y = i;
            }
            if (posteriori != null) {
                posteriori[i] = f;
            }
        }
        if (posteriori != null) {
            // Subtract the best score before exponentiating so the largest term is exp(0).
            double sum = 0.0;
            for (int i = 0; i < this.k; ++i) {
                posteriori[i] = Math.exp(posteriori[i] - max);
                sum += posteriori[i];
            }
            for (int i = 0; i < this.k; ++i) {
                posteriori[i] /= sum;
            }
        }
        return y;
    }

    /**
     * Trainer that builds {@link LDA} models from a configurable prior and tolerance.
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        /** A priori class probabilities, or {@code null} to estimate them from the labels. */
        private double[] priori;
        /** Singularity tolerance passed to the {@link LDA} constructor. */
        private double tol = DEFAULT_TOLERANCE;

        /**
         * Sets the a priori class probabilities.
         *
         * @param priori the priors, or {@code null} to estimate them from the training labels;
         *        validated when {@link #train} is called
         * @return this trainer
         */
        public Trainer setPriori(double[] priori) {
            this.priori = priori;
            return this;
        }

        /**
         * Sets the singularity tolerance.
         *
         * @param tol a non-negative tolerance
         * @return this trainer
         * @throws IllegalArgumentException if {@code tol} is negative or NaN
         */
        public Trainer setTolerance(double tol) {
            // Negated form so that NaN is rejected as well.
            if (!(tol >= 0.0)) {
                throw new IllegalArgumentException("Invalid tol: " + tol);
            }
            this.tol = tol;
            return this;
        }

        /**
         * Trains an {@link LDA} model with the configured prior and tolerance.
         *
         * @param x training samples, one row per sample
         * @param y class labels in {@code [0, k)}
         * @return the trained model
         */
        public LDA train(double[][] x, int[] y) {
            return new LDA(x, y, this.priori, this.tol);
        }
    }
}

