/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.math.Math;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;
import smile.projection.Projection;

/**
 * Fisher's linear discriminant (FLD) for multi-class classification and for
 * supervised linear dimensionality reduction.
 *
 * <p>Training finds the {@code L} directions {@code w} that maximize the ratio
 * {@code w'Bw / w'Tw}. Here {@code T} is the total covariance of the training samples
 * (divided by {@code n}). {@code B} is the unweighted average, over the {@code k}
 * classes, of the outer products of the class means centered at the overall mean.
 * These directions are the leading solutions of the symmetric-definite generalized
 * eigenproblem {@code B w = lambda T w}.
 *
 * <p>A new sample is projected onto these directions. It is assigned to the class
 * whose projected mean is nearest in Euclidean distance.
 *
 * <p>Class labels must be the contiguous integers {@code 0, 1, ..., k-1}.
 */
public class FLD
implements Classifier<double[]>,
Projection<double[]> {
    private static final long serialVersionUID = 1L;
    /** Dimension of the input space. */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Overall (column) mean of the training samples, length p. */
    private final double[] mean;
    /** Class means minus the overall mean, k x p. */
    private final DenseMatrix mu;
    /** Projection matrix, p x L; its columns are the discriminant directions. */
    private final DenseMatrix scaling;
    /** Overall mean projected onto the discriminant directions, length L. */
    private final double[] smean;
    /** Centered class means projected onto the discriminant directions, k x L. */
    private final double[][] smu;

    /**
     * Trains FLD with the default mapped dimension and the default tolerance
     * {@code 1E-4}.
     *
     * @param x training samples, one row per sample.
     * @param y class labels in {@code 0..k-1}, one per sample.
     */
    public FLD(double[][] x, int[] y) {
        this(x, y, -1);
    }

    /**
     * Trains FLD with the given mapped dimension and the default tolerance
     * {@code 1E-4}.
     *
     * @param x training samples, one row per sample.
     * @param y class labels in {@code 0..k-1}, one per sample.
     * @param L dimension of the mapped space. A non-positive value selects
     *          {@code min(k - 1, p)}.
     */
    public FLD(double[][] x, int[] y, int L) {
        this(x, y, L, 1.0E-4);
    }

    /**
     * Trains FLD.
     *
     * @param x   training samples, one row per sample; all rows must have the same length p.
     * @param y   class labels, one per sample. The distinct labels must be exactly
     *            {@code 0, 1, ..., k-1} with {@code k >= 2}.
     * @param L   dimension of the mapped space. Must be less than k and at most p.
     *            A non-positive value selects {@code min(k - 1, p)}.
     * @param tol tolerance for singularity. Training is rejected if any eigenvalue
     *            (variance) of the total covariance is below {@code tol * tol}.
     * @throws IllegalArgumentException on mismatched sizes, invalid or non-contiguous
     *         labels, too few samples, an invalid {@code L} or {@code tol}, ragged
     *         rows, or a near-singular covariance matrix.
     */
    public FLD(double[][] x, int[] y, int L, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        this.k = countClasses(y);
        if (tol < 0.0) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        if (x.length <= this.k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", x.length, this.k));
        }
        if (L >= this.k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", L, this.k));
        }
        this.p = x[0].length;
        for (double[] xi : x) {
            if (xi.length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", xi.length, this.p));
            }
        }
        if (L <= 0) {
            // B has rank at most min(k - 1, p); extra directions carry no discriminant
            // information, and more than p of them do not exist.
            L = java.lang.Math.min(this.k - 1, this.p);
        } else if (L > this.p) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space exceeds the input dimension: %d > %d", L, this.p));
        }

        this.mean = Math.colMeans(x);
        this.mu = centeredClassMeans(x, y, this.k, this.p, this.mean);

        DenseMatrix T = totalScatter(x, this.p, this.mean);
        DenseMatrix B = betweenScatter(this.mu, this.k, this.p);
        DenseMatrix W = whitening(T, tol);

        // With W'TW = I, the generalized problem B w = lambda T w is equivalent to the
        // symmetric problem (W'BW) v = lambda v with w = W v. The original code scaled
        // only part of U'B and then treated the non-symmetric T^-1 B as symmetric.
        DenseMatrix WtB = (DenseMatrix) W.atbmm(B);
        DenseMatrix M = (DenseMatrix) WtB.abmm(W);
        symmetrize(M, this.p);
        EVD eigen = M.eigen();
        DenseMatrix directions = (DenseMatrix) W.abmm(eigen.getEigenVectors());

        // Keep the first L directions. As in the original code, this assumes eigen()
        // orders eigenpairs by descending eigenvalue.
        this.scaling = Matrix.zeros(this.p, L);
        for (int j = 0; j < L; ++j) {
            for (int i = 0; i < this.p; ++i) {
                this.scaling.set(i, j, directions.get(i, j));
            }
        }

        this.smean = new double[L];
        this.scaling.atx(this.mean, this.smean);
        this.smu = ((DenseMatrix) this.mu.abmm(this.scaling)).array();
    }

    /**
     * Validates the class labels and returns the number of classes.
     *
     * @param y class labels.
     * @return the number of distinct labels k.
     * @throws IllegalArgumentException if any label is negative, the labels have a gap,
     *         there are fewer than two classes, or the smallest label is not 0.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        int classes = labels.length;
        if (classes < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        // Labels index per-class arrays directly, so they must start at 0. Otherwise the
        // largest label equals k and would overflow those arrays.
        if (labels[0] != 0) {
            throw new IllegalArgumentException("Missing class: 0");
        }
        return classes;
    }

    /** Computes the k x p matrix whose row c is (mean of class c) - (overall mean). */
    private static DenseMatrix centeredClassMeans(double[][] x, int[] y, int k, int p, double[] mean) {
        int[] ni = new int[k];
        DenseMatrix centered = Matrix.zeros(k, p);
        for (int i = 0; i < x.length; ++i) {
            int c = y[i];
            ni[c]++;
            for (int j = 0; j < p; ++j) {
                centered.add(c, j, x[i][j]);
            }
        }
        for (int c = 0; c < k; ++c) {
            for (int j = 0; j < p; ++j) {
                centered.div(c, j, ni[c]);
                centered.sub(c, j, mean[j]);
            }
        }
        return centered;
    }

    /** Computes the p x p total covariance T = (1/n) sum_i (x_i - mean)(x_i - mean)'. */
    private static DenseMatrix totalScatter(double[][] x, int p, double[] mean) {
        DenseMatrix T = Matrix.zeros(p, p);
        for (double[] xi : x) {
            for (int j = 0; j < p; ++j) {
                double dj = xi[j] - mean[j];
                // Accumulate the lower triangle only; it is mirrored below.
                for (int l = 0; l <= j; ++l) {
                    T.add(j, l, dj * (xi[l] - mean[l]));
                }
            }
        }
        for (int j = 0; j < p; ++j) {
            for (int l = 0; l <= j; ++l) {
                T.div(j, l, x.length);
                T.set(l, j, T.get(j, l));
            }
        }
        T.setSymmetric(true);
        return T;
    }

    /** Computes the p x p between-class scatter B = (1/k) sum_c mu_c mu_c'. */
    private static DenseMatrix betweenScatter(DenseMatrix mu, int k, int p) {
        DenseMatrix B = Matrix.zeros(p, p);
        for (int c = 0; c < k; ++c) {
            for (int j = 0; j < p; ++j) {
                for (int l = 0; l <= j; ++l) {
                    B.add(j, l, mu.get(c, j) * mu.get(c, l));
                }
            }
        }
        for (int j = 0; j < p; ++j) {
            for (int l = 0; l <= j; ++l) {
                B.div(j, l, k);
                B.set(l, j, B.get(j, l));
            }
        }
        return B;
    }

    /**
     * Returns W = U diag(1 / sqrt(lambda)) from the eigen decomposition T = U diag(lambda) U',
     * so that W'TW = I.
     *
     * @throws IllegalArgumentException if any eigenvalue of T is below {@code tol * tol}.
     */
    private static DenseMatrix whitening(DenseMatrix T, double tol) {
        EVD eigen = T.eigen();
        double[] lambda = eigen.getEigenValues();
        double threshold = tol * tol;
        for (double v : lambda) {
            if (v < threshold) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
        }
        DenseMatrix W = eigen.getEigenVectors();
        int dim = lambda.length;
        for (int j = 0; j < dim; ++j) {
            double factor = 1.0 / java.lang.Math.sqrt(lambda[j]);
            for (int i = 0; i < dim; ++i) {
                W.mul(i, j, factor);
            }
        }
        return W;
    }

    /** Averages A with its transpose in place (removes round-off asymmetry) and flags it symmetric. */
    private static void symmetrize(DenseMatrix A, int dim) {
        for (int j = 0; j < dim; ++j) {
            for (int l = 0; l < j; ++l) {
                double a = 0.5 * (A.get(j, l) + A.get(l, j));
                A.set(j, l, a);
                A.set(l, j, a);
            }
        }
        A.setSymmetric(true);
    }

    /**
     * Predicts the class of a sample as the class whose projected (centered) mean is
     * nearest to the projected sample.
     *
     * @param x a sample of length p.
     * @return the predicted class label in {@code 0..k-1}.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public int predict(double[] x) {
        // project() validates the input length with the same message.
        double[] wx = this.project(x);
        int y = 0;
        double nearest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.k; ++i) {
            double d = Math.distance(wx, this.smu[i]);
            if (d < nearest) {
                nearest = d;
                y = i;
            }
        }
        return y;
    }

    /**
     * Projects a sample onto the discriminant directions, centered at the projected
     * overall training mean.
     *
     * @param x a sample of length p.
     * @return a new array of length L equal to {@code scaling' (x - mean)}.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public double[] project(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        double[] y = new double[this.scaling.ncols()];
        this.scaling.atx(x, y);
        Math.minus(y, this.smean);
        return y;
    }

    /**
     * Projects each row of {@code x}; see {@link #project(double[])}.
     *
     * @param x samples, one row per sample, each of length p.
     * @return a new n x L array of projected samples.
     * @throws IllegalArgumentException if any row length differs from p.
     */
    public double[][] project(double[][] x) {
        double[][] y = new double[x.length][this.scaling.ncols()];
        for (int i = 0; i < x.length; ++i) {
            if (x[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x[i].length, this.p));
            }
            this.scaling.atx(x[i], y[i]);
            Math.minus(y[i], this.smean);
        }
        return y;
    }

    /**
     * Returns the p x L projection matrix. Its columns are the discriminant directions,
     * scaled so that each has unit variance under the total covariance.
     * The internal matrix is returned, not a copy.
     */
    public DenseMatrix getProjection() {
        return this.scaling;
    }

    /** Builder-style trainer for {@link FLD}. */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        /** Mapped dimension; -1 selects the default of {@link FLD#FLD(double[][], int[], int, double)}. */
        private int L = -1;
        /** Singularity tolerance. */
        private double tol = 1.0E-4;

        /**
         * Sets the dimension of the mapped space.
         *
         * @param L a positive dimension.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code L < 1}.
         */
        public Trainer setDimension(int L) {
            if (L < 1) {
                throw new IllegalArgumentException("Invalid mapping space dimension: " + L);
            }
            this.L = L;
            return this;
        }

        /**
         * Sets the singularity tolerance.
         *
         * @param tol a non-negative tolerance.
         * @return this trainer.
         * @throws IllegalArgumentException if {@code tol < 0}.
         */
        public Trainer setTolerance(double tol) {
            if (tol < 0.0) {
                throw new IllegalArgumentException("Invalid tol: " + tol);
            }
            this.tol = tol;
            return this;
        }

        /** Trains an {@link FLD} with this trainer's settings. */
        public FLD train(double[][] x, int[] y) {
            return new FLD(x, y, this.L, this.tol);
        }
    }
}

