/*
 * Decompiled with CFR 0.152.
 */
package smile.math.matrix;

import java.util.Arrays;
import smile.math.Math;
import smile.math.matrix.Matrix;
import smile.math.matrix.MatrixMultiplication;

/**
 * A sparse matrix stored in compressed sparse column (CSC) format.
 * <p>
 * The stored entries of column {@code j} occupy the half-open range
 * {@code [colIndex[j], colIndex[j + 1])} of the parallel arrays {@code rowIndex}
 * (row numbers) and {@code x} (values); {@code colIndex[ncols]} is the number of
 * stored entries. Row indices within a column are in increasing order for matrices
 * built from a dense array and for the results of {@link #transpose()}, {@link #aat()}
 * and {@link #ata()}; they are NOT sorted in the results of {@link #abmm(SparseMatrix)}
 * (and therefore {@link #abtmm(SparseMatrix)} / {@link #atbmm(SparseMatrix)}).
 */
public class SparseMatrix implements Matrix, MatrixMultiplication<SparseMatrix, SparseMatrix> {
    private static final long serialVersionUID = 1L;

    /** Conservative upper bound on array length, used to cap buffer growth in {@link #abmm}. */
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;

    /** Number of rows. */
    private final int nrows;
    /** Number of columns. */
    private final int ncols;
    /** Column pointers; expected length {@code ncols + 1}. */
    private final int[] colIndex;
    /** Row index of each stored entry. */
    private final int[] rowIndex;
    /** Value of each stored entry. */
    private final double[] x;
    /** Symmetry hint set by callers (or propagated by {@link #transpose()}); never checked against the data. */
    private boolean symmetric = false;

    /**
     * Allocates an uninitialized matrix with room for {@code nvals} entries.
     * Callers must fill {@code colIndex}, {@code rowIndex} and {@code x}.
     */
    private SparseMatrix(int nrows, int ncols, int nvals) {
        this.nrows = nrows;
        this.ncols = ncols;
        this.rowIndex = new int[nvals];
        this.colIndex = new int[ncols + 1];
        this.x = new double[nvals];
    }

    /**
     * Wraps existing CSC arrays. The arrays are stored by reference (not copied)
     * and are not validated.
     *
     * @param nrows    the number of rows.
     * @param ncols    the number of columns.
     * @param x        the value of each stored entry.
     * @param rowIndex the row index of each stored entry.
     * @param colIndex the column pointers, of length {@code ncols + 1}.
     */
    public SparseMatrix(int nrows, int ncols, double[] x, int[] rowIndex, int[] colIndex) {
        this.nrows = nrows;
        this.ncols = ncols;
        this.rowIndex = rowIndex;
        this.colIndex = colIndex;
        this.x = x;
    }

    /**
     * Builds a sparse matrix from a dense array, dropping entries whose
     * magnitude is below {@code 100 * Math.EPSILON}.
     *
     * @param D a rectangular dense matrix indexed {@code D[row][column]}.
     */
    public SparseMatrix(double[][] D) {
        this(D, 100.0 * Math.EPSILON);
    }

    /**
     * Builds a sparse matrix from a dense array. Entries with
     * {@code |D[i][j]| >= tol} are stored; smaller entries (and NaN) are dropped.
     * An empty array ({@code D.length == 0}) yields a 0 x 0 matrix.
     *
     * @param D   a rectangular dense matrix indexed {@code D[row][column]}.
     * @param tol the magnitude threshold for keeping an entry.
     * @throws IllegalArgumentException if the rows of {@code D} differ in length.
     */
    public SparseMatrix(double[][] D, double tol) {
        this.nrows = D.length;
        this.ncols = D.length == 0 ? 0 : D[0].length;

        // Pass 1 walks D row by row (cache friendly) and counts kept entries per column.
        int[] count = new int[ncols];
        for (int i = 0; i < nrows; ++i) {
            double[] row = D[i];
            if (row.length != ncols) {
                throw new IllegalArgumentException("Row " + i + " has " + row.length + " columns, expected " + ncols);
            }
            for (int j = 0; j < ncols; ++j) {
                if (Math.abs(row[j]) >= tol) {
                    ++count[j];
                }
            }
        }

        this.colIndex = new int[ncols + 1];
        for (int j = 0; j < ncols; ++j) {
            this.colIndex[j + 1] = this.colIndex[j] + count[j];
        }
        int nnz = this.colIndex[ncols];
        this.rowIndex = new int[nnz];
        this.x = new double[nnz];

        // Pass 2 places entries via a per-column cursor; visiting rows in increasing
        // order keeps each column's row indices sorted.
        int[] next = Arrays.copyOf(this.colIndex, ncols);
        for (int i = 0; i < nrows; ++i) {
            double[] row = D[i];
            for (int j = 0; j < ncols; ++j) {
                double v = row[j];
                if (Math.abs(v) >= tol) {
                    int k = next[j]++;
                    this.rowIndex[k] = i;
                    this.x[k] = v;
                }
            }
        }
    }

    @Override
    public boolean isSymmetric() {
        return symmetric;
    }

    @Override
    public void setSymmetric(boolean symmetric) {
        this.symmetric = symmetric;
    }

    @Override
    public int nrows() {
        return nrows;
    }

    @Override
    public int ncols() {
        return ncols;
    }

    /**
     * Returns the number of stored entries, i.e. {@code colIndex[ncols]}.
     *
     * @return the number of stored entries.
     */
    public int size() {
        return colIndex[ncols];
    }

    /**
     * Returns the internal value array (not a copy). Only the first
     * {@link #size()} entries belong to the matrix.
     *
     * @return the internal array of stored values.
     */
    public double[] values() {
        return x;
    }

    /**
     * Returns entry (i, j) by scanning the stored entries of column {@code j};
     * returns 0.0 when the entry is not stored.
     *
     * @throws IllegalArgumentException if {@code i} or {@code j} is out of range.
     */
    @Override
    public double get(int i, int j) {
        if (i < 0 || i >= nrows || j < 0 || j >= ncols) {
            throw new IllegalArgumentException("Invalid index: i = " + i + " j = " + j);
        }
        for (int k = colIndex[j]; k < colIndex[j + 1]; ++k) {
            if (rowIndex[k] == i) {
                return x[k];
            }
        }
        return 0.0;
    }

    /** Computes {@code y = A * x}; every element of {@code y} is cleared first. Returns {@code y}. */
    @Override
    public double[] ax(double[] x, double[] y) {
        Arrays.fill(y, 0.0);
        accumulateAx(x, y);
        return y;
    }

    /** Computes {@code y += A * x} and returns {@code y}. */
    @Override
    public double[] axpy(double[] x, double[] y) {
        accumulateAx(x, y);
        return y;
    }

    /** Computes {@code y = A * x + b * y}, scaling every element of {@code y} by {@code b}. Returns {@code y}. */
    @Override
    public double[] axpy(double[] x, double[] y, double b) {
        for (int i = 0; i < y.length; ++i) {
            y[i] *= b;
        }
        accumulateAx(x, y);
        return y;
    }

    /** Adds {@code A * v} into {@code y}, column by column. */
    private void accumulateAx(double[] v, double[] y) {
        for (int j = 0; j < ncols; ++j) {
            for (int k = colIndex[j]; k < colIndex[j + 1]; ++k) {
                y[rowIndex[k]] += x[k] * v[j];
            }
        }
    }

    /** Computes {@code y = A' * x}; every element of {@code y} is cleared first. Returns {@code y}. */
    @Override
    public double[] atx(double[] x, double[] y) {
        Arrays.fill(y, 0.0);
        accumulateAtx(x, y);
        return y;
    }

    /** Computes {@code y += A' * x} and returns {@code y}. */
    @Override
    public double[] atxpy(double[] x, double[] y) {
        accumulateAtx(x, y);
        return y;
    }

    /** Computes {@code y = A' * x + b * y}; only the first {@code ncols} elements of {@code y} are scaled. Returns {@code y}. */
    @Override
    public double[] atxpy(double[] x, double[] y, double b) {
        for (int j = 0; j < ncols; ++j) {
            y[j] *= b;
            for (int k = colIndex[j]; k < colIndex[j + 1]; ++k) {
                y[j] += this.x[k] * x[rowIndex[k]];
            }
        }
        return y;
    }

    /** Adds {@code A' * v} into {@code y}: {@code y[j]} gains the dot product of column j with {@code v}. */
    private void accumulateAtx(double[] v, double[] y) {
        for (int j = 0; j < ncols; ++j) {
            for (int k = colIndex[j]; k < colIndex[j + 1]; ++k) {
                y[j] += x[k] * v[rowIndex[k]];
            }
        }
    }

    /**
     * Returns the transpose as a new matrix with sorted row indices in every column.
     * The {@code symmetric} flag is carried over.
     */
    @Override
    public SparseMatrix transpose() {
        SparseMatrix at = new SparseMatrix(ncols, nrows, size());

        // Count entries per row of this matrix (= per column of the transpose).
        int[] count = new int[nrows];
        for (int j = 0; j < ncols; ++j) {
            for (int p = colIndex[j]; p < colIndex[j + 1]; ++p) {
                ++count[rowIndex[p]];
            }
        }
        for (int i = 0; i < nrows; ++i) {
            at.colIndex[i + 1] = at.colIndex[i] + count[i];
        }

        // Scatter entries; visiting source columns in order keeps the output rows sorted.
        int[] next = Arrays.copyOf(at.colIndex, nrows);
        for (int j = 0; j < ncols; ++j) {
            for (int p = colIndex[j]; p < colIndex[j + 1]; ++p) {
                int q = next[rowIndex[p]]++;
                at.rowIndex[q] = j;
                at.x[q] = x[p];
            }
        }
        at.symmetric = symmetric;
        return at;
    }

    /**
     * Returns the product {@code A * B}. The result's storage is trimmed to its
     * number of entries. Row indices within each result column are not sorted.
     *
     * @throws IllegalArgumentException if {@code ncols() != B.nrows()}.
     */
    @Override
    public SparseMatrix abmm(SparseMatrix B) {
        if (ncols != B.nrows) {
            throw new IllegalArgumentException(String.format("Matrix dimensions do not match for matrix multiplication: %d x %d vs %d x %d", this.nrows(), this.ncols(), B.nrows(), B.ncols()));
        }
        int m = nrows;
        int n = B.ncols;
        int[] Bp = B.colIndex;
        int[] Bi = B.rowIndex;
        double[] Bx = B.x;

        int[] w = new int[m];          // w[i] == j + 1 marks row i as already present in column j
        double[] abj = new double[m];  // dense accumulator for the current result column
        int nzmax = Math.max(size() + Bp[n], m);
        int[] Cp = new int[n + 1];
        int[] Ci = new int[nzmax];
        double[] Cx = new double[nzmax];

        int nz = 0;
        for (int j = 0; j < n; ++j) {
            // A result column holds at most m entries; grow first so scatter never overruns.
            // Growth is computed in long to avoid int overflow on very large products.
            if ((long) nz + m > nzmax) {
                long grown = 2L * nzmax + m;
                nzmax = (int) (grown > MAX_ARRAY_SIZE ? MAX_ARRAY_SIZE : grown);
                Ci = Arrays.copyOf(Ci, nzmax);
                Cx = Arrays.copyOf(Cx, nzmax);
            }
            Cp[j] = nz;
            for (int p = Bp[j]; p < Bp[j + 1]; ++p) {
                nz = scatter(this, Bi[p], Bx[p], w, abj, j + 1, Ci, nz);
            }
            for (int p = Cp[j]; p < nz; ++p) {
                Cx[p] = abj[Ci[p]];
            }
        }
        Cp[n] = nz;
        return new SparseMatrix(m, n, Arrays.copyOf(Cx, nz), Arrays.copyOf(Ci, nz), Cp);
    }

    /**
     * Adds {@code beta * A[:, j]} into the dense accumulator {@code x}. Rows seen for
     * the first time in the current column (those with {@code w[i] < mark}) are stamped with
     * {@code mark} and appended to {@code Ci} at position {@code nz}.
     *
     * @return the updated entry count {@code nz}.
     */
    private static int scatter(SparseMatrix A, int j, double beta, int[] w, double[] x, int mark, int[] Ci, int nz) {
        for (int p = A.colIndex[j]; p < A.colIndex[j + 1]; ++p) {
            int i = A.rowIndex[p];
            if (w[i] < mark) {
                w[i] = mark;
                Ci[nz++] = i;
                x[i] = beta * A.x[p];
            } else {
                x[i] += beta * A.x[p];
            }
        }
        return nz;
    }

    /** Returns {@code A * B'}, computed as {@code abmm(B.transpose())}. */
    @Override
    public SparseMatrix abtmm(SparseMatrix B) {
        return abmm(B.transpose());
    }

    /** Returns {@code A' * B}, computed as {@code transpose().abmm(B)}. */
    @Override
    public SparseMatrix atbmm(SparseMatrix B) {
        return transpose().abmm(B);
    }

    /** Returns {@code A' * A} with sorted row indices. The result's {@code symmetric} flag is not set. */
    @Override
    public SparseMatrix ata() {
        SparseMatrix AT = transpose();
        return AT.aat(this);
    }

    /** Returns {@code A * A'} with sorted row indices. The result's {@code symmetric} flag is not set. */
    @Override
    public SparseMatrix aat() {
        SparseMatrix AT = transpose();
        return aat(AT);
    }

    /**
     * Computes {@code this * AT} as an {@code nrows x nrows} matrix, where {@code AT}
     * must have {@code nrows} columns (normally the transpose of {@code this}).
     * <p>
     * Pass 1 counts the result pattern, pass 2 records and sorts it, and pass 3
     * fills in the values using a dense per-column accumulator.
     */
    private SparseMatrix aat(SparseMatrix AT) {
        int m = nrows;
        int[] done = new int[m];  // done[h] == j means row h is already in result column j

        Arrays.fill(done, -1);
        int nvals = 0;
        for (int j = 0; j < m; ++j) {
            for (int p = AT.colIndex[j]; p < AT.colIndex[j + 1]; ++p) {
                int k = AT.rowIndex[p];
                for (int q = colIndex[k]; q < colIndex[k + 1]; ++q) {
                    int h = rowIndex[q];
                    if (done[h] != j) {
                        done[h] = j;
                        ++nvals;
                    }
                }
            }
        }

        SparseMatrix c = new SparseMatrix(m, m, nvals);
        Arrays.fill(done, -1);
        nvals = 0;
        for (int j = 0; j < m; ++j) {
            c.colIndex[j] = nvals;
            for (int p = AT.colIndex[j]; p < AT.colIndex[j + 1]; ++p) {
                int k = AT.rowIndex[p];
                for (int q = colIndex[k]; q < colIndex[k + 1]; ++q) {
                    int h = rowIndex[q];
                    if (done[h] != j) {
                        done[h] = j;
                        c.rowIndex[nvals++] = h;
                    }
                }
            }
        }
        c.colIndex[m] = nvals;
        for (int j = 0; j < m; ++j) {
            if (c.colIndex[j + 1] - c.colIndex[j] > 1) {
                Arrays.sort(c.rowIndex, c.colIndex[j], c.colIndex[j + 1]);
            }
        }

        double[] temp = new double[m];
        for (int j = 0; j < m; ++j) {
            for (int p = AT.colIndex[j]; p < AT.colIndex[j + 1]; ++p) {
                int k = AT.rowIndex[p];
                double a = AT.x[p];
                for (int q = colIndex[k]; q < colIndex[k + 1]; ++q) {
                    temp[rowIndex[q]] += a * x[q];
                }
            }
            // Every touched row is in column j's pattern, so this also resets temp to zero.
            for (int p = c.colIndex[j]; p < c.colIndex[j + 1]; ++p) {
                int h = c.rowIndex[p];
                c.x[p] = temp[h];
                temp[h] = 0.0;
            }
        }
        return c;
    }

    /** Returns the main diagonal, of length {@code min(nrows, ncols)}; entries that are not stored are 0.0. */
    @Override
    public double[] diag() {
        int n = Math.min(nrows, ncols);
        double[] d = new double[n];
        for (int i = 0; i < n; ++i) {
            for (int k = colIndex[i]; k < colIndex[i + 1]; ++k) {
                if (rowIndex[k] == i) {
                    d[i] = x[k];
                    break;
                }
            }
        }
        return d;
    }
}

