/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import smile.classification.ClassifierTrainer;
import smile.classification.SoftClassifier;
import smile.data.Attribute;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.sort.QuickSort;
import smile.util.MulticoreExecutor;

/**
 * Classification tree grown best-first.
 *
 * <p>Every internal node tests one feature: a nominal feature is tested for equality
 * with a single category code, a numeric feature with {@code x <= threshold}.
 * Candidate splits are scored by the decrease in impurity under the configured
 * {@link SplitRule}; the pending node with the highest score is split next until
 * {@code maxNodes} leaves exist or no node can be split further.
 */
public class DecisionTree
implements SoftClassifier<double[]> {
    private static final long serialVersionUID = 1L;
    /** Attribute descriptor of each feature column. */
    private Attribute[] attributes;
    /** Per feature, the sum of the scores of all splits made on it. */
    private double[] importance;
    /** Root of the fitted tree. */
    private Node root;
    /** Impurity measure used to score splits. */
    private SplitRule rule = SplitRule.GINI;
    /** Number of classes. */
    private int k = 2;
    /** Minimum total sample weight required on each side of a split. */
    private int nodeSize = 1;
    /** Maximum number of leaf nodes. */
    private int maxNodes = 100;
    /** Number of features evaluated when searching for a node's split. */
    private int mtry;
    /** Per numeric feature, the sample indices in sorted order; only needed while training. */
    private transient int[][] order;

    /**
     * Computes the impurity of a node from its class counts under {@link #rule}.
     *
     * @param count per-class sample weights
     * @param n     total weight, used as the denominator of the class proportions
     * @return the Gini index, the entropy in bits, or the classification error
     */
    private double impurity(int[] count, int n) {
        double impurity = 0.0;
        switch (this.rule) {
            case GINI: {
                impurity = 1.0;
                for (int c : count) {
                    if (c <= 0) continue;
                    double p = (double) c / (double) n;
                    impurity -= p * p;
                }
                break;
            }
            case ENTROPY: {
                for (int c : count) {
                    if (c <= 0) continue;
                    double p = (double) c / (double) n;
                    impurity -= p * Math.log2(p);
                }
                break;
            }
            case CLASSIFICATION_ERROR: {
                double maxP = 0.0;
                for (int c : count) {
                    if (c <= 0) continue;
                    maxP = Math.max(maxP, (double) c / (double) n);
                }
                impurity = Math.abs(1.0 - maxP);
                break;
            }
        }
        return impurity;
    }

    /**
     * Returns the number of feature columns, rejecting empty training data.
     * Used by the delegating constructors, which otherwise would read {@code x[0]}
     * before any validation could run.
     */
    private static int numFeatures(double[][] x) {
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data.");
        }
        return x[0].length;
    }

    /**
     * Grows a tree on numeric features using the Gini index.
     *
     * @param x        training samples, one row per sample
     * @param y        class labels {@code 0..k-1}
     * @param maxNodes maximum number of leaf nodes
     */
    public DecisionTree(double[][] x, int[] y, int maxNodes) {
        this(null, x, y, maxNodes);
    }

    /**
     * Grows a tree on numeric features with the given split rule.
     *
     * @param x        training samples, one row per sample
     * @param y        class labels {@code 0..k-1}
     * @param maxNodes maximum number of leaf nodes
     * @param rule     impurity measure used to score splits
     */
    public DecisionTree(double[][] x, int[] y, int maxNodes, SplitRule rule) {
        this(null, x, y, maxNodes, 1, rule);
    }

    /**
     * Grows a tree on numeric features with the given split rule and minimum node size.
     *
     * @param x        training samples, one row per sample
     * @param y        class labels {@code 0..k-1}
     * @param maxNodes maximum number of leaf nodes
     * @param nodeSize minimum total sample weight on each side of a split
     * @param rule     impurity measure used to score splits
     */
    public DecisionTree(double[][] x, int[] y, int maxNodes, int nodeSize, SplitRule rule) {
        this(null, x, y, maxNodes, nodeSize, rule);
    }

    /**
     * Grows a tree using the Gini index.
     *
     * @param attributes attribute descriptors, or {@code null} to treat every feature as numeric
     * @param x          training samples, one row per sample
     * @param y          class labels {@code 0..k-1}
     * @param maxNodes   maximum number of leaf nodes
     */
    public DecisionTree(Attribute[] attributes, double[][] x, int[] y, int maxNodes) {
        this(attributes, x, y, maxNodes, SplitRule.GINI);
    }

    /**
     * Grows a tree that evaluates every feature at each split.
     *
     * @param attributes attribute descriptors, or {@code null} to treat every feature as numeric
     * @param x          training samples, one row per sample
     * @param y          class labels {@code 0..k-1}
     * @param maxNodes   maximum number of leaf nodes
     * @param rule       impurity measure used to score splits
     */
    public DecisionTree(Attribute[] attributes, double[][] x, int[] y, int maxNodes, SplitRule rule) {
        this(attributes, x, y, maxNodes, 1, numFeatures(x), rule, null, null);
    }

    /**
     * Grows a tree that evaluates every feature at each split.
     *
     * @param attributes attribute descriptors, or {@code null} to treat every feature as numeric
     * @param x          training samples, one row per sample
     * @param y          class labels {@code 0..k-1}
     * @param maxNodes   maximum number of leaf nodes
     * @param nodeSize   minimum total sample weight on each side of a split
     * @param rule       impurity measure used to score splits
     */
    public DecisionTree(Attribute[] attributes, double[][] x, int[] y, int maxNodes, int nodeSize, SplitRule rule) {
        this(attributes, x, y, maxNodes, nodeSize, numFeatures(x), rule, null, null);
    }

    /**
     * Grows a classification tree.
     *
     * @param attributes attribute descriptors (one per column of {@code x}), or {@code null}
     *                   to treat every feature as numeric with names {@code V1..Vp}
     * @param x          training samples, one row per sample
     * @param y          class labels; every value in {@code 0..k-1} must occur
     * @param maxNodes   maximum number of leaf nodes (at least 2)
     * @param nodeSize   minimum total sample weight on each side of a split (at least 1)
     * @param mtry       number of features evaluated per split, in {@code [1, p]}; when smaller
     *                   than {@code p} a randomly permuted subset of that size is tried
     * @param rule       impurity measure used to score splits
     * @param samples    non-negative integer weight of each sample, or {@code null} for weight 1;
     *                   samples of weight 0 are ignored. The array is not modified.
     * @param order      precomputed per-feature sort orders of the numeric features, or
     *                   {@code null} to compute them here
     * @throws IllegalArgumentException if the arguments are inconsistent or out of range
     */
    public DecisionTree(Attribute[] attributes, double[][] x, int[] y, int maxNodes, int nodeSize, int mtry, SplitRule rule, int[] samples, int[][] order) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int p = numFeatures(x);
        if (mtry < 1 || mtry > p) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
        }
        if (nodeSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
        }
        if (attributes != null && attributes.length != p) {
            throw new IllegalArgumentException(String.format("The number of attributes and features don't match: %d != %d", attributes.length, p));
        }
        int n = y.length;
        if (samples != null && samples.length != n) {
            throw new IllegalArgumentException(String.format("The sizes of samples and Y don't match: %d != %d", samples.length, n));
        }

        // Labels must be exactly 0..k-1 so they can index per-class arrays.
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int c = 0; c < labels.length; ++c) {
            if (labels[c] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[c]);
            }
            if (labels[c] != c) {
                throw new IllegalArgumentException("Missing class: " + c);
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        if (samples == null) {
            samples = new int[n];
            Arrays.fill(samples, 1);
        }
        int[] count = new int[this.k];
        int total = 0;
        for (int i = 0; i < n; ++i) {
            if (samples[i] < 0) {
                throw new IllegalArgumentException("Negative sample weight: " + samples[i]);
            }
            count[y[i]] += samples[i];
            total += samples[i];
        }
        if (total <= 0) {
            throw new IllegalArgumentException("No training samples selected.");
        }

        if (attributes == null) {
            attributes = new Attribute[p];
            for (int j = 0; j < p; ++j) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }
        this.attributes = attributes;
        this.mtry = mtry;
        this.nodeSize = nodeSize;
        this.maxNodes = maxNodes;
        this.rule = rule;
        this.importance = new double[attributes.length];
        this.order = order != null ? order : sortNumericFeatures(x, attributes);

        // Normalize by the total sample weight, which differs from y.length when weights are given.
        // NOTE(review): unlike the children, the root's estimate is not add-one smoothed.
        double[] posteriori = new double[this.k];
        for (int c = 0; c < this.k; ++c) {
            posteriori[c] = (double) count[c] / (double) total;
        }
        this.root = new Node(Math.whichMax(count), posteriori);

        // Best-first growth: always split the pending node with the highest split score.
        PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
        TrainNode trainRoot = new TrainNode(this.root, x, y, samples);
        if (trainRoot.findBestSplit()) {
            nextSplits.add(trainRoot);
        }
        int leaves = 1;
        while (leaves < this.maxNodes) {
            TrainNode node = nextSplits.poll();
            if (node == null) {
                break;
            }
            // Only a successful split turns one leaf into two.
            if (node.split(nextSplits)) {
                ++leaves;
            }
        }
    }

    /**
     * Computes, for every numeric feature, the sample index order produced by
     * {@code QuickSort.sort} on that column; other features get {@code null}.
     * The threshold search walks these orders in ascending value order.
     */
    private static int[][] sortNumericFeatures(double[][] x, Attribute[] attributes) {
        int n = x.length;
        int p = x[0].length;
        double[] column = new double[n];
        int[][] sorted = new int[p][];
        for (int j = 0; j < p; ++j) {
            if (!(attributes[j] instanceof NumericAttribute)) continue;
            for (int i = 0; i < n; ++i) {
                column[i] = x[i][j];
            }
            sorted[j] = QuickSort.sort(column);
        }
        return sorted;
    }

    /**
     * Decides whether a feature value follows the true branch of a split node:
     * equality for nominal features, {@code <=} for numeric features.
     *
     * @throws IllegalStateException if the split feature has an unsupported type
     */
    private boolean followsTrueBranch(Node node, double value) {
        Attribute.Type type = this.attributes[node.splitFeature].getType();
        if (type == Attribute.Type.NOMINAL) {
            return value == node.splitValue;
        }
        if (type == Attribute.Type.NUMERIC) {
            return value <= node.splitValue;
        }
        throw new IllegalStateException("Unsupported attribute type: " + type);
    }

    /**
     * Returns, per feature, the sum of the scores (impurity decreases) of all splits on it.
     * The internal array is returned, not a copy.
     */
    public double[] importance() {
        return this.importance;
    }

    @Override
    public int predict(double[] x) {
        return this.root.predict(x);
    }

    /**
     * Predicts the class of {@code x} and copies the reached leaf's class probabilities
     * into the first {@code k} entries of {@code posteriori}.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        return this.root.predict(x, posteriori);
    }

    /** Returns the number of nodes on the longest root-to-leaf path (1 for a single leaf). */
    public int maxDepth() {
        return this.maxDepth(this.root);
    }

    private int maxDepth(Node node) {
        if (node == null) {
            return 0;
        }
        int lDepth = this.maxDepth(node.trueChild);
        int rDepth = this.maxDepth(node.falseChild);
        return (lDepth > rDepth ? lDepth : rDepth) + 1;
    }

    /**
     * Renders the tree in Graphviz DOT format.
     *
     * <p>Nodes are numbered breadth-first starting with 0 for the root; only the two
     * edges leaving the root carry True/False labels.
     *
     * @throws IllegalStateException if a split feature has an unsupported type
     */
    public String dot() {
        StringBuilder builder = new StringBuilder();
        builder.append("digraph DecisionTree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n");
        int lastId = 0;
        LinkedList<DotNode> queue = new LinkedList<DotNode>();
        queue.add(new DotNode(-1, 0, this.root));
        while (!queue.isEmpty()) {
            DotNode dnode = queue.poll();
            int id = dnode.id;
            int parent = dnode.parent;
            Node node = dnode.node;
            if (node.trueChild == null && node.falseChild == null) {
                builder.append(String.format(" %d [label=<class = %d>, fillcolor=\"#00000000\", shape=ellipse];\n", id, node.output));
            } else {
                Attribute attr = this.attributes[node.splitFeature];
                if (attr.getType() == Attribute.Type.NOMINAL) {
                    builder.append(String.format(" %d [label=<%s = %s<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, attr.getName(), attr.toString(node.splitValue), node.splitScore));
                } else if (attr.getType() == Attribute.Type.NUMERIC) {
                    builder.append(String.format(" %d [label=<%s &le; %.4f<br/>score = %.4f>, fillcolor=\"#00000000\"];\n", id, attr.getName(), node.splitValue, node.splitScore));
                } else {
                    throw new IllegalStateException("Unsupported attribute type: " + attr.getType());
                }
            }
            if (parent >= 0) {
                builder.append(' ').append(parent).append(" -> ").append(id);
                if (parent == 0) {
                    // The root's true child is enqueued first and therefore gets id 1.
                    if (id == 1) {
                        builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
                    } else {
                        builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                    }
                }
                builder.append(";\n");
            }
            if (node.trueChild != null) {
                queue.add(new DotNode(id, ++lastId, node.trueChild));
            }
            if (node.falseChild != null) {
                queue.add(new DotNode(id, ++lastId, node.falseChild));
            }
        }
        builder.append("}");
        return builder.toString();
    }

    /** Returns the root node of the fitted tree. */
    public Node getRoot() {
        return this.root;
    }

    /** Queue entry for {@link #dot()}: a node with its own id and its parent's id (-1 for the root). */
    private class DotNode {
        int parent;
        int id;
        Node node;

        DotNode(int parent, int id, Node node) {
            this.parent = parent;
            this.id = id;
            this.node = node;
        }
    }

    /** Training-time state of a tree node: the node plus the weights of the samples reaching it. */
    class TrainNode
    implements Comparable<TrainNode> {
        Node node;
        double[][] x;
        int[] y;
        /** Weight of each training sample at this node; weight {@code <= 0} means it does not reach the node. */
        int[] samples;

        public TrainNode(Node node, double[][] x, int[] y, int[] samples) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.samples = samples;
        }

        /** Orders by descending split score so a priority queue yields the best split first. */
        @Override
        public int compareTo(TrainNode a) {
            return Double.compare(a.node.splitScore, this.node.splitScore);
        }

        /** Returns true if all samples reaching this node share one label (or none reach it). */
        private boolean isPure() {
            int label = -1;
            for (int i = 0; i < this.x.length; ++i) {
                if (this.samples[i] <= 0) continue;
                if (label == -1) {
                    label = this.y[i];
                } else if (this.y[i] != label) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Searches {@code mtry} features for the best split of this node and stores it in
         * {@link #node}.
         *
         * @return true if a split with positive score was found
         */
        public boolean findBestSplit() {
            if (this.isPure()) {
                return false;
            }
            int n = 0;
            for (int s : this.samples) {
                n += s;
            }
            if (n <= DecisionTree.this.nodeSize) {
                return false;
            }
            int[] count = new int[DecisionTree.this.k];
            for (int i = 0; i < this.x.length; ++i) {
                if (this.samples[i] <= 0) continue;
                count[this.y[i]] += this.samples[i];
            }
            double impurity = DecisionTree.this.impurity(count, n);
            int p = DecisionTree.this.attributes.length;
            int[] variables = new int[p];
            for (int j = 0; j < p; ++j) {
                variables[j] = j;
            }
            int[] falseCount = new int[DecisionTree.this.k];
            if (DecisionTree.this.mtry < p) {
                // Random feature subset, searched on this thread.
                // NOTE(review): presumably sequential because feature-subsampling callers
                // (e.g. random forests) already parallelize across trees — confirm.
                Math.permutate(variables);
                this.searchSequentially(n, count, falseCount, impurity, variables);
            } else {
                ArrayList<SplitTask> tasks = new ArrayList<SplitTask>(DecisionTree.this.mtry);
                for (int j = 0; j < DecisionTree.this.mtry; ++j) {
                    tasks.add(new SplitTask(n, count, impurity, variables[j]));
                }
                try {
                    for (Node split : MulticoreExecutor.run(tasks)) {
                        this.adopt(split);
                    }
                } catch (Exception ex) {
                    if (ex instanceof InterruptedException) {
                        // Keep the interrupt visible to the caller.
                        Thread.currentThread().interrupt();
                    }
                    // Fall back to evaluating the candidate features on this thread.
                    this.searchSequentially(n, count, falseCount, impurity, variables);
                }
            }
            return this.node.splitFeature != -1;
        }

        /** Evaluates the first {@code mtry} entries of {@code variables} on the current thread. */
        private void searchSequentially(int n, int[] count, int[] falseCount, double impurity, int[] variables) {
            for (int j = 0; j < DecisionTree.this.mtry; ++j) {
                this.adopt(this.findBestSplit(n, count, falseCount, impurity, variables[j]));
            }
        }

        /** Copies {@code split} into this node if it scores strictly higher than the current split. */
        private void adopt(Node split) {
            if (split.splitScore > this.node.splitScore) {
                this.node.splitFeature = split.splitFeature;
                this.node.splitValue = split.splitValue;
                this.node.splitScore = split.splitScore;
                this.node.trueChildOutput = split.trueChildOutput;
                this.node.falseChildOutput = split.falseChildOutput;
            }
        }

        /**
         * Scores the split whose true side has class weights {@code trueCount}. If both sides
         * hold at least {@code nodeSize} weight and the gain beats {@code best.splitScore},
         * records score and child outputs in {@code best}; the caller sets feature and value.
         *
         * @param falseCount scratch array receiving the false side's class weights
         * @return true if the split was recorded
         */
        private boolean consider(Node best, int n, int[] count, int[] trueCount, int[] falseCount, double impurity) {
            int tc = Math.sum(trueCount);
            int fc = n - tc;
            if (tc < DecisionTree.this.nodeSize || fc < DecisionTree.this.nodeSize) {
                return false;
            }
            for (int q = 0; q < DecisionTree.this.k; ++q) {
                falseCount[q] = count[q] - trueCount[q];
            }
            double gain = impurity
                    - (double) tc / (double) n * DecisionTree.this.impurity(trueCount, tc)
                    - (double) fc / (double) n * DecisionTree.this.impurity(falseCount, fc);
            if (!(gain > best.splitScore)) {
                return false;
            }
            best.splitScore = gain;
            best.trueChildOutput = Math.whichMax(trueCount);
            best.falseChildOutput = Math.whichMax(falseCount);
            return true;
        }

        /**
         * Finds the best split of this node on feature {@code j}.
         *
         * <p>Nominal features try each category code as the true branch. Numeric features try
         * midpoints between consecutive distinct values. A cut is skipped only when the values
         * on both sides of it each belong to one and the same class.
         *
         * @param n          total sample weight at this node
         * @param count      per-class weights at this node
         * @param falseCount scratch array of length {@code k}
         * @param impurity   impurity of this node
         * @param j          feature index
         * @return a node holding the best split; {@code splitFeature == -1} if none has positive score
         * @throws IllegalStateException if feature {@code j} has an unsupported type
         */
        public Node findBestSplit(int n, int[] count, int[] falseCount, double impurity, int j) {
            Node splitNode = new Node();
            Attribute.Type type = DecisionTree.this.attributes[j].getType();
            if (type == Attribute.Type.NOMINAL) {
                int m = ((NominalAttribute) DecisionTree.this.attributes[j]).size();
                int[][] trueCount = new int[m][DecisionTree.this.k];
                for (int i = 0; i < this.x.length; ++i) {
                    if (this.samples[i] <= 0) continue;
                    trueCount[(int) this.x[i][j]][this.y[i]] += this.samples[i];
                }
                for (int l = 0; l < m; ++l) {
                    if (this.consider(splitNode, n, count, trueCount[l], falseCount, impurity)) {
                        splitNode.splitFeature = j;
                        splitNode.splitValue = l;
                    }
                }
            } else if (type == Attribute.Type.NUMERIC) {
                final int mixed = -2;          // run label of a run containing several classes
                int[] sorted = DecisionTree.this.order[j];
                int[] trueCount = new int[DecisionTree.this.k];
                double prevx = Double.NaN;     // value of the previous run; NaN before the first run
                int prevLabel = -1;            // class of the previous run, or mixed
                int pos = 0;
                while (pos < sorted.length) {
                    int head = sorted[pos];
                    if (this.samples[head] <= 0) {
                        ++pos;
                        continue;
                    }
                    // Collect the run of present samples sharing this value and check whether it is single-class.
                    double value = this.x[head][j];
                    int runLabel = this.y[head];
                    int end = pos + 1;
                    while (end < sorted.length) {
                        int i = sorted[end];
                        if (this.samples[i] > 0) {
                            if (this.x[i][j] != value) break;
                            if (this.y[i] != runLabel) runLabel = mixed;
                        }
                        ++end;
                    }
                    // A cut between two single-class runs of the same class cannot be optimal; any other cut may be.
                    boolean sameClassRuns = prevLabel >= 0 && prevLabel == runLabel;
                    if (!Double.isNaN(prevx) && !Double.isNaN(value) && !sameClassRuns
                            && this.consider(splitNode, n, count, trueCount, falseCount, impurity)) {
                        splitNode.splitFeature = j;
                        splitNode.splitValue = (value + prevx) / 2.0;
                    }
                    // Move the run to the true (<= threshold) side for the following cuts.
                    for (int q = pos; q < end; ++q) {
                        int i = sorted[q];
                        if (this.samples[i] > 0) {
                            trueCount[this.y[i]] += this.samples[i];
                        }
                    }
                    prevx = value;
                    prevLabel = runLabel;
                    pos = end;
                }
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + type);
            }
            return splitNode;
        }

        /**
         * Applies the split stored in {@link #node}, creating both children with add-one
         * smoothed class probabilities. Each child that can be split further is either added
         * to {@code nextSplits} or, when {@code nextSplits} is null, split recursively right away.
         * The sample weights of this node are left unchanged.
         *
         * @return false (and the split is cleared) if a side would hold less than {@code nodeSize} weight
         * @throws IllegalStateException if the node has no split feature
         */
        public boolean split(PriorityQueue<TrainNode> nextSplits) {
            if (this.node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            int feature = this.node.splitFeature;
            int n = this.x.length;
            int tc = 0;
            int fc = 0;
            int[] trueSamples = new int[n];
            int[] falseSamples = new int[n];
            for (int i = 0; i < n; ++i) {
                if (this.samples[i] <= 0) continue;
                if (DecisionTree.this.followsTrueBranch(this.node, this.x[i][feature])) {
                    trueSamples[i] = this.samples[i];
                    tc += trueSamples[i];
                } else {
                    falseSamples[i] = this.samples[i];
                    fc += falseSamples[i];
                }
            }
            if (tc < DecisionTree.this.nodeSize || fc < DecisionTree.this.nodeSize) {
                this.node.splitFeature = -1;
                this.node.splitValue = Double.NaN;
                this.node.splitScore = 0.0;
                return false;
            }
            double[] trueChildPosteriori = new double[DecisionTree.this.k];
            double[] falseChildPosteriori = new double[DecisionTree.this.k];
            for (int i = 0; i < n; ++i) {
                trueChildPosteriori[this.y[i]] += trueSamples[i];
                falseChildPosteriori[this.y[i]] += falseSamples[i];
            }
            for (int c = 0; c < DecisionTree.this.k; ++c) {
                trueChildPosteriori[c] = (trueChildPosteriori[c] + 1.0) / (double) (tc + DecisionTree.this.k);
                falseChildPosteriori[c] = (falseChildPosteriori[c] + 1.0) / (double) (fc + DecisionTree.this.k);
            }
            this.node.trueChild = new Node(this.node.trueChildOutput, trueChildPosteriori);
            this.node.falseChild = new Node(this.node.falseChildOutput, falseChildPosteriori);
            this.growChild(new TrainNode(this.node.trueChild, this.x, this.y, trueSamples), tc, nextSplits);
            this.growChild(new TrainNode(this.node.falseChild, this.x, this.y, falseSamples), fc, nextSplits);
            DecisionTree.this.importance[feature] += this.node.splitScore;
            return true;
        }

        /** Queues {@code child} (or splits it immediately if {@code nextSplits} is null) when it has a usable split. */
        private void growChild(TrainNode child, int weight, PriorityQueue<TrainNode> nextSplits) {
            if (weight <= DecisionTree.this.nodeSize || !child.findBestSplit()) {
                return;
            }
            if (nextSplits != null) {
                nextSplits.add(child);
            } else {
                child.split(null);
            }
        }

        /** Finds the best split on one feature; used for the multicore search. */
        class SplitTask
        implements Callable<Node> {
            int n;
            int[] count;
            double impurity;
            int j;

            SplitTask(int n, int[] count, double impurity, int j) {
                this.n = n;
                this.count = count;
                this.impurity = impurity;
                this.j = j;
            }

            @Override
            public Node call() {
                // Each task needs its own scratch array; count is only read.
                int[] falseCount = new int[DecisionTree.this.k];
                return TrainNode.this.findBestSplit(this.n, this.count, falseCount, this.impurity, this.j);
            }
        }
    }

    /**
     * A node of the fitted tree. A node without children is a leaf that predicts
     * {@code output} with class probabilities {@code posteriori}; otherwise it routes
     * a sample by testing feature {@code splitFeature} against {@code splitValue}.
     */
    class Node
    implements Serializable {
        /** Predicted class at this node. */
        int output = -1;
        /** Class probability estimates at this node. */
        double[] posteriori = null;
        /** Feature tested by the split, or -1 if the node has no split. */
        int splitFeature = -1;
        /** Nominal category code or numeric threshold of the split. */
        double splitValue = Double.NaN;
        /** Impurity decrease achieved by the split. */
        double splitScore = 0.0;
        /** Child for samples that pass the split test. */
        Node trueChild = null;
        /** Child for samples that fail the split test. */
        Node falseChild = null;
        /** Predicted class of the true child, set when the split is chosen. */
        int trueChildOutput = -1;
        /** Predicted class of the false child, set when the split is chosen. */
        int falseChildOutput = -1;

        public Node() {
        }

        public Node(int output, double[] posteriori) {
            this.output = output;
            this.posteriori = posteriori;
        }

        /**
         * Walks from this node down to the leaf that {@code x} falls into.
         * Kept private so the default serialVersionUID of this Serializable class is unchanged.
         */
        private Node leaf(double[] x) {
            Node node = this;
            while (node.trueChild != null || node.falseChild != null) {
                node = DecisionTree.this.followsTrueBranch(node, x[node.splitFeature]) ? node.trueChild : node.falseChild;
            }
            return node;
        }

        /** Returns the class predicted for {@code x} by the subtree rooted here. */
        public int predict(double[] x) {
            return this.leaf(x).output;
        }

        /**
         * Returns the class predicted for {@code x} and copies the leaf's class
         * probabilities into the first {@code k} entries of {@code posteriori}.
         */
        public int predict(double[] x, double[] posteriori) {
            Node leaf = this.leaf(x);
            System.arraycopy(leaf.posteriori, 0, posteriori, 0, DecisionTree.this.k);
            return leaf.output;
        }
    }

    /** Impurity measure used to score candidate splits. */
    public static enum SplitRule {
        /** Gini index: {@code 1 - sum(p_i^2)}. */
        GINI,
        /** Entropy in bits: {@code -sum(p_i log2 p_i)}. */
        ENTROPY,
        /** Classification error: {@code 1 - max(p_i)}. */
        CLASSIFICATION_ERROR;

    }

    /** Builder-style trainer for {@link DecisionTree}. */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        private SplitRule rule = SplitRule.GINI;
        private int nodeSize = 1;
        private int maxNodes = 100;

        public Trainer() {
        }

        /**
         * @param maxNodes maximum number of leaf nodes (at least 2)
         * @throws IllegalArgumentException if {@code maxNodes < 2}
         */
        public Trainer(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
        }

        /**
         * @param attributes attribute descriptors passed to the trained trees
         * @param maxNodes   maximum number of leaf nodes (at least 2)
         * @throws IllegalArgumentException if {@code maxNodes < 2}
         */
        public Trainer(Attribute[] attributes, int maxNodes) {
            super(attributes);
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
        }

        /** Sets the impurity measure used to score splits. */
        public Trainer setSplitRule(SplitRule rule) {
            this.rule = rule;
            return this;
        }

        /**
         * Sets the maximum number of leaf nodes.
         *
         * @throws IllegalArgumentException if {@code maxNodes < 2}
         */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /**
         * Sets the minimum total sample weight on each side of a split.
         *
         * @throws IllegalArgumentException if {@code nodeSize < 1}
         */
        public Trainer setNodeSize(int nodeSize) {
            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
            }
            this.nodeSize = nodeSize;
            return this;
        }

        /** Grows a tree on {@code x}, {@code y} with this trainer's settings. */
        public DecisionTree train(double[][] x, int[] y) {
            return new DecisionTree(this.attributes, x, y, this.maxNodes, this.nodeSize, this.rule);
        }
    }
}

