/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import java.util.HashMap;
import java.util.Map;
import smile.math.Math;
import smile.sequence.SequenceLabeler;

/**
 * First-order hidden Markov model with a discrete emission alphabet.
 *
 * <p>The model is described by the initial state distribution {@code pi},
 * the state transition matrix {@code a} ({@code a[i][j]} is the probability
 * of moving from state {@code i} to state {@code j}) and the emission matrix
 * {@code b} ({@code b[i][k]} is the probability of emitting symbol {@code k}
 * in state {@code i}). States and symbols are addressed by zero-based integer
 * index; an optional symbol table maps objects of type {@code O} to symbol
 * indices so that the object-based overloads can be used.
 *
 * <p>The class offers supervised estimation from labeled sequences
 * (relative-frequency counting), re-estimation from unlabeled sequences
 * (scaled forward-backward, i.e. Baum-Welch style updates), likelihood
 * evaluation and most-likely-path decoding (Viterbi in the log domain).
 *
 * <p>Parameter arrays handed to the public constructors are stored by
 * reference and the getters return the internal arrays, so external
 * modification of those arrays changes the model.
 *
 * @param <O> the type of the observation symbols
 */
public class HMM<O>
implements SequenceLabeler<O> {
    /** Probabilities below this value are treated as this value when taking logs. */
    private static final double MIN_PROBABILITY = 1.0E-300;
    /** Floor returned by {@link #log(double)}; approximately {@code ln(1.0E-300)}. */
    private static final double LOG_MIN_PROBABILITY = -690.7755;
    /** Allowed deviation of a probability row sum from 1. */
    private static final double ROW_SUM_TOLERANCE = 1.0E-7;

    /** Number of hidden states. */
    private int numStates;
    /** Number of distinct emission symbols. */
    private int numSymbols;
    /** Initial state probabilities, indexed by state. */
    private double[] pi;
    /** State transition probabilities, {@code a[from][to]}. */
    private double[][] a;
    /** Symbol emission probabilities, {@code b[state][symbol]}. */
    private double[][] b;
    /** Optional mapping from symbol object to symbol index; {@code null} if none was supplied. */
    private Map<O, Integer> symbols;

    /**
     * Creates a model of the given size with all probabilities set to zero.
     * Used internally as the target of a re-estimation step.
     *
     * @param numStates the number of states, must be positive
     * @param numSymbols the number of emission symbols, must be positive
     * @throws IllegalArgumentException if either argument is not positive
     */
    private HMM(int numStates, int numSymbols) {
        if (numStates <= 0) {
            throw new IllegalArgumentException("Invalid number of states: " + numStates);
        }
        if (numSymbols <= 0) {
            throw new IllegalArgumentException("Invalid number of emission symbols: " + numSymbols);
        }
        this.numStates = numStates;
        this.numSymbols = numSymbols;
        this.pi = new double[numStates];
        this.a = new double[numStates][numStates];
        this.b = new double[numStates][numSymbols];
    }

    /**
     * Creates a model from explicit parameters without a symbol table.
     *
     * @param pi the initial state probabilities
     * @param a the state transition probabilities
     * @param b the symbol emission probabilities
     * @throws IllegalArgumentException if the parameters are inconsistent or invalid
     * @see #HMM(double[], double[][], double[][], Object[])
     */
    public HMM(double[] pi, double[][] a, double[][] b) {
        this(pi, a, b, null);
    }

    /**
     * Creates a model from explicit parameters.
     *
     * <p>Every row of {@code a} and {@code b} must have the expected length,
     * contain only values in {@code [0, 1]} (NaN is rejected) and sum to 1
     * within a tolerance of {@code 1e-7}. Only the length of {@code pi} is
     * checked; its values are not validated. The arrays are stored by
     * reference, not copied.
     *
     * @param pi the initial state probabilities; its length defines the number of states
     * @param a the square state transition matrix
     * @param b the emission matrix with one row per state
     * @param symbols the emission symbols in index order, or {@code null} for none;
     *        if a symbol occurs more than once, its last position wins
     * @throws IllegalArgumentException if {@code pi} is empty, the dimensions
     *         disagree, or a row of {@code a} or {@code b} is not a valid
     *         probability distribution
     */
    public HMM(double[] pi, double[][] a, double[][] b, O[] symbols) {
        if (pi.length == 0) {
            throw new IllegalArgumentException("Invalid initial state probabilities.");
        }
        if (pi.length != a.length) {
            throw new IllegalArgumentException("Invalid state transition probability matrix.");
        }
        if (a.length != b.length) {
            throw new IllegalArgumentException("Invalid symbol emission probability matrix.");
        }
        if (symbols != null) {
            if (b[0].length != symbols.length) {
                throw new IllegalArgumentException("Invalid size of emission symbol list.");
            }
            this.symbols = new HashMap<O, Integer>();
            for (int k = 0; k < symbols.length; ++k) {
                this.symbols.put(symbols[k], k);
            }
        }
        this.numStates = pi.length;
        this.numSymbols = b[0].length;
        HMM.checkStochasticMatrix(a, this.numStates, "state transition");
        HMM.checkStochasticMatrix(b, this.numSymbols, "symbol emission");
        this.pi = pi;
        this.a = a;
        this.b = b;
    }

    /**
     * Verifies that every row of a matrix is a probability distribution over
     * {@code columns} outcomes.
     *
     * @param matrix the matrix to check
     * @param columns the required row length
     * @param kind the wording inserted into error messages
     *        ({@code "state transition"} or {@code "symbol emission"})
     * @throws IllegalArgumentException if a row has the wrong length, contains a
     *         value outside {@code [0, 1]} (including NaN), or does not sum to 1
     */
    private static void checkStochasticMatrix(double[][] matrix, int columns, String kind) {
        for (int row = 0; row < matrix.length; ++row) {
            if (matrix[row].length != columns) {
                throw new IllegalArgumentException("Invalid " + kind + " probability matrix.");
            }
            double sum = 0.0;
            for (int col = 0; col < columns; ++col) {
                double value = matrix[row][col];
                // Written as a negated range test so that NaN, for which every
                // comparison is false, is rejected as well.
                if (!(value >= 0.0 && value <= 1.0)) {
                    throw new IllegalArgumentException("Invalid " + kind + " probability: " + value);
                }
                sum += value;
            }
            if (Math.abs(1.0 - sum) > ROW_SUM_TOLERANCE) {
                throw new IllegalArgumentException(String.format("The row %d of %s probability matrix doesn't sum to 1.", row, kind));
            }
        }
    }

    /**
     * Returns the number of states.
     *
     * @return the number of states
     */
    public int numStates() {
        return this.numStates;
    }

    /**
     * Returns the number of emission symbols.
     *
     * @return the number of emission symbols
     */
    public int numSymbols() {
        return this.numSymbols;
    }

    /**
     * Returns the initial state probabilities.
     *
     * @return the internal array (not a copy)
     */
    public double[] getInitialStateProbabilities() {
        return this.pi;
    }

    /**
     * Returns the state transition probabilities.
     *
     * @return the internal matrix (not a copy), indexed {@code [from][to]}
     */
    public double[][] getStateTransitionProbabilities() {
        return this.a;
    }

    /**
     * Returns the symbol emission probabilities.
     *
     * @return the internal matrix (not a copy), indexed {@code [state][symbol]}
     */
    public double[][] getSymbolEmissionProbabilities() {
        return this.b;
    }

    /**
     * Natural logarithm with a finite floor.
     *
     * @param x the value, normally a probability
     * @return {@code -690.7755} if {@code x} is below {@code 1e-300} (this
     *         covers zero and negative values), otherwise {@code ln(x)}
     */
    private static double log(double x) {
        if (x < MIN_PROBABILITY) {
            return LOG_MIN_PROBABILITY;
        }
        return java.lang.Math.log(x);
    }

    /**
     * Returns the joint probability of an observation sequence and a state sequence.
     *
     * @param o the observation sequence as symbol indices
     * @param s the state sequence
     * @return {@code exp(logp(o, s))}
     * @throws IllegalArgumentException if the sequences differ in length
     */
    public double p(int[] o, int[] s) {
        return Math.exp(this.logp(o, s));
    }

    /**
     * Returns the log joint probability of an observation sequence and a
     * state sequence.
     *
     * <p>Each factor is passed through a floored logarithm, so a path that
     * uses a zero probability yields a very negative finite value rather than
     * negative infinity. A pair of empty sequences has log probability 0.
     *
     * @param o the observation sequence as symbol indices
     * @param s the state sequence
     * @return the log joint probability
     * @throws IllegalArgumentException if the sequences differ in length
     */
    public double logp(int[] o, int[] s) {
        if (o.length != s.length) {
            throw new IllegalArgumentException("The observation sequence and state sequence are not the same length.");
        }
        if (s.length == 0) {
            // Empty product of probabilities.
            return 0.0;
        }
        double logProb = HMM.log(this.pi[s[0]]) + HMM.log(this.b[s[0]][o[0]]);
        for (int t = 1; t < s.length; ++t) {
            logProb += HMM.log(this.a[s[t - 1]][s[t]]) + HMM.log(this.b[s[t]][o[t]]);
        }
        return logProb;
    }

    /**
     * Returns the probability of an observation sequence under this model.
     *
     * @param o the observation sequence as symbol indices
     * @return {@code exp(logp(o))}
     */
    public double p(int[] o) {
        return Math.exp(this.logp(o));
    }

    /**
     * Returns the log probability of an observation sequence under this
     * model, computed with the scaled forward procedure as the sum of the
     * logs of the per-step scaling factors.
     *
     * @param o the observation sequence as symbol indices
     * @return the log likelihood; 0 for an empty sequence and
     *         {@link Double#NEGATIVE_INFINITY} if the model assigns the
     *         sequence zero probability
     */
    public double logp(int[] o) {
        if (o.length == 0) {
            return 0.0;
        }
        double[][] alpha = new double[o.length][this.numStates];
        double[] scaling = new double[o.length];
        this.forward(o, alpha, scaling);
        double logLikelihood = 0.0;
        for (int t = 0; t < o.length; ++t) {
            if (scaling[t] == 0.0) {
                // All probability mass vanished at step t. The forward pass
                // divided by this zero, so later scaling factors are NaN and
                // must not be accumulated.
                return Double.NEGATIVE_INFINITY;
            }
            logLikelihood += java.lang.Math.log(scaling[t]);
        }
        return logLikelihood;
    }

    /**
     * Normalizes row {@code t} of {@code alpha} to sum to one and records the
     * pre-normalization sum in {@code scaling[t]}.
     *
     * @param scaling receives the row sum at index {@code t}
     * @param alpha the forward variables, modified in place
     * @param t the time step to normalize
     */
    private void scale(double[] scaling, double[][] alpha, int t) {
        double[] row = alpha[t];
        double total = 0.0;
        for (int i = 0; i < row.length; ++i) {
            total += row[i];
        }
        scaling[t] = total;
        for (int i = 0; i < row.length; ++i) {
            row[i] /= total;
        }
    }

    /**
     * Scaled forward procedure. Fills {@code alpha} with the normalized
     * forward variables and {@code scaling} with the per-step normalizers.
     *
     * @param o the non-empty observation sequence as symbol indices
     * @param alpha output array of shape {@code [o.length][numStates]}
     * @param scaling output array of length {@code o.length}
     */
    private void forward(int[] o, double[][] alpha, double[] scaling) {
        for (int k = 0; k < this.numStates; ++k) {
            alpha[0][k] = this.pi[k] * this.b[k][o[0]];
        }
        this.scale(scaling, alpha, 0);
        for (int t = 1; t < o.length; ++t) {
            for (int k = 0; k < this.numStates; ++k) {
                double sum = 0.0;
                for (int i = 0; i < this.numStates; ++i) {
                    sum += alpha[t - 1][i] * this.a[i][k];
                }
                alpha[t][k] = sum * this.b[k][o[t]];
            }
            this.scale(scaling, alpha, t);
        }
    }

    /**
     * Scaled backward procedure using the scaling factors produced by
     * {@link #forward(int[], double[][], double[])}.
     *
     * @param o the non-empty observation sequence as symbol indices
     * @param beta output array of shape {@code [o.length][numStates]}
     * @param scaling the scaling factors from the forward pass
     */
    private void backward(int[] o, double[][] beta, double[] scaling) {
        int last = o.length - 1;
        for (int i = 0; i < this.numStates; ++i) {
            beta[last][i] = 1.0 / scaling[last];
        }
        for (int t = last - 1; t >= 0; --t) {
            for (int i = 0; i < this.numStates; ++i) {
                double sum = 0.0;
                for (int j = 0; j < this.numStates; ++j) {
                    sum += beta[t + 1][j] * this.a[i][j] * this.b[j][o[t + 1]];
                }
                beta[t][i] = sum / scaling[t];
            }
        }
    }

    /**
     * Returns the most likely state sequence for an observation sequence,
     * found by dynamic programming over floored log probabilities with
     * back-pointers. Ties are resolved in favor of the lowest state index.
     *
     * @param o the observation sequence as symbol indices
     * @return the state sequence, of the same length as {@code o}
     *         (empty if {@code o} is empty)
     */
    public int[] predict(int[] o) {
        int[] path = new int[o.length];
        if (o.length == 0) {
            return path;
        }
        double[][] trellis = new double[o.length][this.numStates];
        int[][] backPointer = new int[o.length][this.numStates];
        for (int i = 0; i < this.numStates; ++i) {
            trellis[0][i] = HMM.log(this.pi[i]) + HMM.log(this.b[i][o[0]]);
            backPointer[0][i] = 0;
        }
        for (int t = 1; t < o.length; ++t) {
            for (int j = 0; j < this.numStates; ++j) {
                double best = Double.NEGATIVE_INFINITY;
                int bestPrev = 0;
                for (int i = 0; i < this.numStates; ++i) {
                    double candidate = trellis[t - 1][i] + HMM.log(this.a[i][j]);
                    if (best < candidate) {
                        best = candidate;
                        bestPrev = i;
                    }
                }
                trellis[t][j] = best + HMM.log(this.b[j][o[t]]);
                backPointer[t][j] = bestPrev;
            }
        }
        int last = o.length - 1;
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < this.numStates; ++i) {
            if (best < trellis[last][i]) {
                best = trellis[last][i];
                path[last] = i;
            }
        }
        for (int t = last - 1; t >= 0; --t) {
            path[t] = backPointer[t + 1][path[t + 1]];
        }
        return path;
    }

    /**
     * Estimates a model from labeled sequences by counting initial states,
     * transitions and emissions and normalizing the counts.
     *
     * <p>The number of states and symbols is one more than the largest label
     * and observation index seen, respectively.
     *
     * @param observations the observation sequences as symbol indices
     * @param labels the state sequences, aligned with {@code observations}
     * @throws IllegalArgumentException if the number or lengths of the
     *         sequences disagree, or if a sequence is empty
     */
    public HMM(int[][] observations, int[][] labels) {
        this.fit(observations, labels);
    }

    /**
     * Estimates a model from labeled sequences of symbol objects.
     *
     * <p>Symbols are assigned consecutive indices in order of first
     * appearance; the resulting table is kept for the object-based methods.
     * The parameters are then estimated by counting, as in
     * {@link #HMM(int[][], int[][])}.
     *
     * @param observations the observation sequences
     * @param labels the state sequences, aligned with {@code observations}
     * @throws IllegalArgumentException if the number or lengths of the
     *         sequences disagree, or if a sequence is empty
     */
    public HMM(O[][] observations, int[][] labels) {
        HMM.checkSequenceCount(observations.length, labels.length);
        Map<O, Integer> table = new HashMap<O, Integer>();
        int[][] encoded = new int[observations.length][];
        int nextIndex = 0;
        for (int i = 0; i < observations.length; ++i) {
            O[] sequence = observations[i];
            HMM.checkSequenceLength(i, sequence.length, labels[i].length);
            encoded[i] = new int[sequence.length];
            for (int j = 0; j < sequence.length; ++j) {
                Integer code = table.get(sequence[j]);
                if (code == null) {
                    code = nextIndex++;
                    table.put(sequence[j], code);
                }
                encoded[i][j] = code;
            }
        }
        this.symbols = table;
        this.fit(encoded, labels);
    }

    /**
     * Checks that there are as many label sequences as observation sequences.
     *
     * @param observationCount the number of observation sequences
     * @param labelCount the number of label sequences
     * @throws IllegalArgumentException if the counts differ
     */
    private static void checkSequenceCount(int observationCount, int labelCount) {
        if (observationCount != labelCount) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }
    }

    /**
     * Checks that an observation sequence and its label sequence have the same length.
     *
     * @param index the index of the sequence, used in the error message
     * @param observationLength the length of the observation sequence
     * @param labelLength the length of the label sequence
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void checkSequenceLength(int index, int observationLength, int labelLength) {
        if (observationLength != labelLength) {
            throw new IllegalArgumentException(String.format("The length of observation sequence %d and that of corresponding label sequence are different.", index));
        }
    }

    /**
     * Sizes the model from the data and sets {@code pi}, {@code a} and
     * {@code b} to normalized counts. Shared by the supervised constructors.
     *
     * @param observations the observation sequences as symbol indices
     * @param labels the state sequences, aligned with {@code observations}
     * @throws IllegalArgumentException if the number or lengths of the
     *         sequences disagree, or if a sequence is empty
     */
    private void fit(int[][] observations, int[][] labels) {
        HMM.checkSequenceCount(observations.length, labels.length);
        this.numStates = 0;
        this.numSymbols = 0;
        for (int i = 0; i < observations.length; ++i) {
            HMM.checkSequenceLength(i, observations[i].length, labels[i].length);
            if (observations[i].length == 0) {
                // The counting loop below reads element 0 of every sequence.
                throw new IllegalArgumentException(String.format("Observation sequence %d is empty.", i));
            }
            this.numStates = Math.max(this.numStates, Math.max(labels[i]) + 1);
            this.numSymbols = Math.max(this.numSymbols, Math.max(observations[i]) + 1);
        }
        this.pi = new double[this.numStates];
        this.a = new double[this.numStates][this.numStates];
        this.b = new double[this.numStates][this.numSymbols];
        for (int i = 0; i < observations.length; ++i) {
            int[] o = observations[i];
            int[] s = labels[i];
            this.pi[s[0]] += 1.0;
            this.b[s[0]][o[0]] += 1.0;
            for (int t = 1; t < o.length; ++t) {
                this.a[s[t - 1]][s[t]] += 1.0;
                this.b[s[t]][o[t]] += 1.0;
            }
        }
        // NOTE(review): unitize1 presumably rescales each row to sum to 1;
        // its behavior on an all-zero count row (e.g. a state that is never
        // the source of a transition) is not visible here - confirm.
        Math.unitize1(this.pi);
        for (int i = 0; i < this.numStates; ++i) {
            Math.unitize1(this.a[i]);
            Math.unitize1(this.b[i]);
        }
    }

    /**
     * Re-estimates the model from unlabeled sequences of symbol objects.
     *
     * @param observations the training sequences
     * @param iterations the number of re-estimation passes
     * @return the re-estimated model; this model itself if {@code iterations}
     *         is not positive
     * @throws IllegalArgumentException if the model has no symbol table, a
     *         symbol is unknown, or a sequence has fewer than three observations
     */
    public HMM<O> learn(O[][] observations, int iterations) {
        int[][] encoded = new int[observations.length][];
        for (int i = 0; i < encoded.length; ++i) {
            encoded[i] = this.translate(observations[i]);
        }
        return this.learn(encoded, iterations);
    }

    /**
     * Re-estimates the model from unlabeled sequences. Each pass produces a
     * new model from the previous one; this model is not modified.
     *
     * @param observations the training sequences as symbol indices
     * @param iterations the number of re-estimation passes
     * @return the re-estimated model; this model itself if {@code iterations}
     *         is not positive
     * @throws IllegalArgumentException if a sequence has fewer than three observations
     */
    public HMM<O> learn(int[][] observations, int iterations) {
        HMM<O> model = this;
        for (int pass = 0; pass < iterations; ++pass) {
            model = model.iterate(observations);
        }
        return model;
    }

    /**
     * Performs one forward-backward re-estimation pass and returns the
     * updated parameters as a new model that shares this model's symbol table.
     *
     * <p>Where the data carry no evidence for a parameter row (a state with
     * zero expected occupancy, or no training sequences at all), the
     * corresponding row of the current model is retained.
     *
     * @param sequences the training sequences as symbol indices
     * @return the re-estimated model
     * @throws IllegalArgumentException if a sequence has fewer than three observations
     */
    private HMM<O> iterate(int[][] sequences) {
        HMM<O> hmm = new HMM<O>(this.numStates, this.numSymbols);
        hmm.symbols = this.symbols;

        // Per-sequence state occupancies, plus accumulators for the expected
        // number of i -> j transitions and of departures from i.
        double[][][] gamma = new double[sequences.length][][];
        double[][] aijNum = new double[this.numStates][this.numStates];
        double[] aijDen = new double[this.numStates];

        for (int k = 0; k < sequences.length; ++k) {
            int[] o = sequences[k];
            if (o.length <= 2) {
                throw new IllegalArgumentException(String.format("Traning sequence %d is too short.", k));
            }
            double[][] alpha = new double[o.length][this.numStates];
            double[][] beta = new double[o.length][this.numStates];
            double[] scaling = new double[o.length];
            this.forward(o, alpha, scaling);
            this.backward(o, beta, scaling);
            double[][][] xi = this.estimateXi(o, alpha, beta);
            double[][] g = this.estimateGamma(xi);
            gamma[k] = g;
            int steps = o.length - 1;
            for (int i = 0; i < this.numStates; ++i) {
                for (int t = 0; t < steps; ++t) {
                    aijDen[i] += g[t][i];
                    for (int j = 0; j < this.numStates; ++j) {
                        aijNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        // Transition probabilities.
        for (int i = 0; i < this.numStates; ++i) {
            if (aijDen[i] == 0.0) {
                System.arraycopy(this.a[i], 0, hmm.a[i], 0, this.numStates);
                continue;
            }
            for (int j = 0; j < this.numStates; ++j) {
                hmm.a[i][j] = aijNum[i][j] / aijDen[i];
            }
        }

        // Initial state probabilities: average occupancy at time 0.
        if (sequences.length == 0) {
            System.arraycopy(this.pi, 0, hmm.pi, 0, this.numStates);
        } else {
            for (int k = 0; k < sequences.length; ++k) {
                for (int i = 0; i < this.numStates; ++i) {
                    hmm.pi[i] += gamma[k][0][i];
                }
            }
            for (int i = 0; i < this.numStates; ++i) {
                hmm.pi[i] /= sequences.length;
            }
        }

        // Emission probabilities.
        for (int i = 0; i < this.numStates; ++i) {
            double occupancy = 0.0;
            for (int k = 0; k < sequences.length; ++k) {
                int[] o = sequences[k];
                for (int t = 0; t < o.length; ++t) {
                    hmm.b[i][o[t]] += gamma[k][t][i];
                    occupancy += gamma[k][t][i];
                }
            }
            if (occupancy == 0.0) {
                // Same treatment as the transition rows above: dividing by a
                // zero occupancy would turn the whole row into NaN.
                System.arraycopy(this.b[i], 0, hmm.b[i], 0, this.numSymbols);
                continue;
            }
            for (int j = 0; j < this.numSymbols; ++j) {
                hmm.b[i][j] /= occupancy;
            }
        }
        return hmm;
    }

    /**
     * Computes, for every time step {@code t} before the last and every pair
     * of states, {@code alpha[t][i] * a[i][j] * b[j][o[t+1]] * beta[t+1][j]}.
     *
     * @param o the observation sequence as symbol indices
     * @param alpha the scaled forward variables
     * @param beta the scaled backward variables
     * @return an array of shape {@code [o.length - 1][numStates][numStates]}
     * @throws IllegalArgumentException if {@code o} has fewer than two observations
     */
    private double[][][] estimateXi(int[] o, double[][] alpha, double[][] beta) {
        if (o.length <= 1) {
            throw new IllegalArgumentException("Observation sequence is too short.");
        }
        int steps = o.length - 1;
        double[][][] xi = new double[steps][this.numStates][this.numStates];
        for (int t = 0; t < steps; ++t) {
            for (int i = 0; i < this.numStates; ++i) {
                for (int j = 0; j < this.numStates; ++j) {
                    xi[t][i][j] = alpha[t][i] * this.a[i][j] * this.b[j][o[t + 1]] * beta[t + 1][j];
                }
            }
        }
        return xi;
    }

    /**
     * Derives per-step state occupancies from the pairwise quantities: for
     * each step {@code t < xi.length} the occupancy of state {@code i} is the
     * sum of {@code xi[t][i][j]} over {@code j}; for the final step it is the
     * sum of the last {@code xi} slice over the source state.
     *
     * @param xi the output of {@link #estimateXi(int[], double[][], double[][])}
     * @return an array of shape {@code [xi.length + 1][numStates]}
     */
    private double[][] estimateGamma(double[][][] xi) {
        double[][] gamma = new double[xi.length + 1][this.numStates];
        for (int t = 0; t < xi.length; ++t) {
            for (int i = 0; i < this.numStates; ++i) {
                for (int j = 0; j < this.numStates; ++j) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }
        int last = xi.length - 1;
        for (int j = 0; j < this.numStates; ++j) {
            for (int i = 0; i < this.numStates; ++i) {
                gamma[xi.length][j] += xi[last][i][j];
            }
        }
        return gamma;
    }

    /**
     * Returns a multi-line description listing the model size and the three
     * parameter sets, each probability formatted with {@code %.4g}.
     *
     * @return the textual description
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("HMM (%d states, %d emission symbols)%n", this.numStates, this.numSymbols));
        sb.append("\tInitial state probability:\n\t\t");
        for (int i = 0; i < this.numStates; ++i) {
            sb.append(String.format("%.4g ", this.pi[i]));
        }
        sb.append("\n\tState transition probability:");
        HMM.appendRows(sb, this.a, this.numStates, this.numStates);
        sb.append("\n\tSymbol emission probability:");
        HMM.appendRows(sb, this.b, this.numStates, this.numSymbols);
        return sb.toString();
    }

    /**
     * Appends the leading {@code rows} x {@code columns} part of a matrix,
     * one indented line per row.
     *
     * @param sb the builder to append to
     * @param matrix the matrix to print
     * @param rows the number of rows to print
     * @param columns the number of columns to print
     */
    private static void appendRows(StringBuilder sb, double[][] matrix, int rows, int columns) {
        for (int i = 0; i < rows; ++i) {
            sb.append("\n\t\t");
            for (int j = 0; j < columns; ++j) {
                sb.append(String.format("%.4g ", matrix[i][j]));
            }
        }
    }

    /**
     * Converts a sequence of symbol objects to symbol indices using this
     * model's symbol table.
     *
     * @param o the observation sequence
     * @return the corresponding symbol indices
     * @throws IllegalArgumentException if the model has no symbol table or a
     *         symbol is not in it
     */
    private int[] translate(O[] o) {
        return HMM.translate(this.symbols, o);
    }

    /**
     * Converts a sequence of symbol objects to symbol indices.
     *
     * @param <T> the symbol type
     * @param symbols the symbol table
     * @param o the observation sequence
     * @return the corresponding symbol indices
     * @throws IllegalArgumentException if {@code symbols} is {@code null} or a
     *         symbol is not in it
     */
    private static <T> int[] translate(Map<T, Integer> symbols, T[] o) {
        if (symbols == null) {
            throw new IllegalArgumentException("No availabe emission symbol list.");
        }
        int[] seq = new int[o.length];
        for (int i = 0; i < o.length; ++i) {
            Integer sym = symbols.get(o[i]);
            if (sym == null) {
                throw new IllegalArgumentException("Invalid observation symbol: " + o[i]);
            }
            seq[i] = sym;
        }
        return seq;
    }

    /**
     * Returns the joint probability of an observation sequence and a state sequence.
     *
     * @param o the observation sequence
     * @param s the state sequence
     * @return {@code exp(logp(o, s))}
     * @throws IllegalArgumentException if the model has no symbol table, a
     *         symbol is unknown, or the sequences differ in length
     */
    public double p(O[] o, int[] s) {
        return Math.exp(this.logp(o, s));
    }

    /**
     * Returns the log joint probability of an observation sequence and a
     * state sequence.
     *
     * @param o the observation sequence
     * @param s the state sequence
     * @return the log joint probability
     * @throws IllegalArgumentException if the model has no symbol table, a
     *         symbol is unknown, or the sequences differ in length
     */
    public double logp(O[] o, int[] s) {
        return this.logp(this.translate(o), s);
    }

    /**
     * Returns the probability of an observation sequence under this model.
     *
     * @param o the observation sequence
     * @return {@code exp(logp(o))}
     * @throws IllegalArgumentException if the model has no symbol table or a
     *         symbol is unknown
     */
    public double p(O[] o) {
        return Math.exp(this.logp(o));
    }

    /**
     * Returns the log probability of an observation sequence under this model.
     *
     * @param o the observation sequence
     * @return the log likelihood
     * @throws IllegalArgumentException if the model has no symbol table or a
     *         symbol is unknown
     */
    public double logp(O[] o) {
        return this.logp(this.translate(o));
    }

    /**
     * Returns the most likely state sequence for an observation sequence.
     *
     * @param o the observation sequence
     * @return the state sequence, of the same length as {@code o}
     * @throws IllegalArgumentException if the model has no symbol table or a
     *         symbol is unknown
     */
    @Override
    public int[] predict(O[] o) {
        return this.predict(this.translate(o));
    }
}

