/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.distance.Metric;
import smile.neighbor.KNNSearch;
import smile.neighbor.NearestNeighborSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;
import smile.sort.DoubleHeapSelect;

/**
 * A cover tree over a fixed array of objects, answering nearest-neighbor,
 * k-nearest-neighbor and range queries under a user supplied {@link Metric}.
 * <p>
 * Every internal node holds the index of a data point, the largest recorded
 * distance from that point to the points inserted beneath it ({@code maxDist}),
 * and its children. The first child of an internal node always refers to the
 * node's own data point, which lets the queries reuse the parent's distance
 * instead of recomputing it. The cover radius at scale {@code s} is
 * {@code base^s}.
 * <p>
 * By default queries skip any data point that is the very same object
 * ({@code ==}) as the query; see {@link #setIdenticalExcluded(boolean)}.
 *
 * @param <E> the type of the indexed objects.
 */
public class CoverTree<E>
        implements NearestNeighborSearch<E, E>, KNNSearch<E, E>, RNNSearch<E, E>, Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(CoverTree.class);

    /** The indexed objects; tree nodes refer to them by array position. */
    private E[] data;
    /** The metric used both to build the tree and to answer queries. */
    private Metric<E> distance;
    /** Root of the tree. */
    private Node root;
    /** Base of the cover radius {@code base^scale}. */
    private double base = 1.3;
    /** Cached {@code 1 / log(base)}, used to map a distance to a scale. */
    private double invLogBase = 1.0 / Math.log((double) this.base);
    /** Whether queries skip data points that are the query object itself. */
    private boolean identicalExcluded = true;

    /**
     * Builds a cover tree with base 1.3.
     *
     * @param dataset the objects to index; must not be empty.
     * @param distance the metric.
     * @throws IllegalArgumentException if {@code dataset} is empty.
     */
    public CoverTree(E[] dataset, Metric<E> distance) {
        this(dataset, distance, 1.3);
    }

    /**
     * Builds a cover tree.
     *
     * @param dataset the objects to index; must not be empty. The array is kept
     *     by reference, not copied.
     * @param distance the metric.
     * @param base the base of the cover radius {@code base^scale}; must be
     *     greater than 1.
     * @throws IllegalArgumentException if {@code dataset} is empty or
     *     {@code base} is not greater than 1.
     */
    public CoverTree(E[] dataset, Metric<E> distance, double base) {
        if (dataset.length == 0) {
            throw new IllegalArgumentException("Empty dataset");
        }
        // With base <= 1 (or NaN), 1/log(base) is infinite, negative or NaN and
        // lowering the scale no longer shrinks the cover radius base^scale, so the
        // construction recursion cannot converge.
        if (!(base > 1.0)) {
            throw new IllegalArgumentException("Invalid base: " + base);
        }
        this.data = dataset;
        this.distance = distance;
        this.base = base;
        this.invLogBase = 1.0 / Math.log((double) base);
        this.buildCoverTree();
    }

    @Override
    public String toString() {
        return String.format("Cover Tree (%s)", this.distance);
    }

    /**
     * Sets whether queries skip data points that are the query object itself
     * (compared with {@code ==}).
     *
     * @param excluded true to skip such points.
     * @return this tree, for chaining.
     */
    public CoverTree<E> setIdenticalExcluded(boolean excluded) {
        this.identicalExcluded = excluded;
        return this;
    }

    /**
     * Returns whether queries skip data points that are the query object itself.
     *
     * @return true if such points are skipped.
     */
    public boolean isIdenticalExcluded() {
        return this.identicalExcluded;
    }

    /**
     * Builds the tree rooted at {@code data[0]}. Every other point starts with a
     * single recorded distance: its distance to the root point.
     */
    private void buildCoverTree() {
        ArrayList<DistanceSet> pointSet = new ArrayList<DistanceSet>(this.data.length - 1);
        ArrayList<DistanceSet> consumedSet = new ArrayList<DistanceSet>();
        E point = this.data[0];
        double maxDist = -1.0;
        for (int i = 1; i < this.data.length; ++i) {
            DistanceSet set = new DistanceSet(i);
            double dist = this.distance.d(point, this.data[i]);
            set.dist.add(dist);
            pointSet.add(set);
            if (dist > maxDist) {
                maxDist = dist;
            }
        }
        int scale = this.getScale(maxDist);
        this.root = this.batchInsert(0, scale, scale, pointSet, consumedSet);
    }

    /**
     * Recursively builds the subtree for point {@code p}.
     * <p>
     * The last entry of each {@link DistanceSet#dist} is that point's distance to
     * {@code p}. Points placed in the subtree are moved to {@code consumedSet};
     * points too far to be covered are left in {@code pointSet} for the caller.
     *
     * @param p index of the point this subtree is rooted at.
     * @param maxScale the current scale.
     * @param topScale the scale of the root of the whole tree.
     * @param pointSet the points still to be inserted (modified in place).
     * @param consumedSet receives the points inserted below {@code p}.
     * @return the subtree rooted at {@code p}.
     */
    private Node batchInsert(int p, int maxScale, int topScale, ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> consumedSet) {
        if (pointSet.isEmpty()) {
            return this.newLeaf(p);
        }

        double maxDist = this.max(pointSet);
        // When maxScale is Integer.MIN_VALUE, maxScale - 1 wraps to Integer.MAX_VALUE,
        // so min() still picks getScale(maxDist). This wrap is relied upon.
        int nextScale = Math.min(maxScale - 1, this.getScale(maxDist));
        if (nextScale == Integer.MIN_VALUE) {
            // getScale() yields Integer.MIN_VALUE for maxDist == 0, i.e. all remaining
            // points are at distance 0 from p: attach each of them as a leaf.
            ArrayList<Node> children = new ArrayList<Node>();
            children.add(this.newLeaf(p));
            while (!pointSet.isEmpty()) {
                DistanceSet set = pointSet.remove(pointSet.size() - 1);
                children.add(this.newLeaf(set.idx));
                consumedSet.add(set);
            }
            Node node = new Node(p);
            node.scale = 100;
            node.maxDist = 0.0;
            node.children = children;
            return node;
        }

        ArrayList<DistanceSet> far = new ArrayList<DistanceSet>();
        this.split(pointSet, far, maxScale);
        // The first child is p itself at the next scale.
        Node child = this.batchInsert(p, nextScale, topScale, pointSet, consumedSet);
        if (pointSet.isEmpty()) {
            pointSet.addAll(far);
            return child;
        }

        ArrayList<Node> children = new ArrayList<Node>();
        children.add(child);
        ArrayList<DistanceSet> newPointSet = new ArrayList<DistanceSet>();
        ArrayList<DistanceSet> newConsumedSet = new ArrayList<DistanceSet>();
        double fmax = this.getCoverRadius(maxScale);
        while (!pointSet.isEmpty()) {
            // Promote the last remaining near point to a new child of p.
            DistanceSet set = pointSet.remove(pointSet.size() - 1);
            double newDist = set.dist.get(set.dist.size() - 1);
            consumedSet.add(set);
            // Gather the points covered by the new child; this pushes their distance
            // to the new child onto their distance stacks.
            this.distSplit(pointSet, newPointSet, set.getObject(), maxScale);
            this.distSplit(far, newPointSet, set.getObject(), maxScale);
            Node newChild = this.batchInsert(set.idx, nextScale, topScale, newPointSet, newConsumedSet);
            newChild.parentDist = newDist;
            children.add(newChild);

            // Pop the distance to the new child so the top entry is again the
            // distance to p, then re-classify the leftovers as near or far.
            for (DistanceSet leftover : newPointSet) {
                leftover.dist.remove(leftover.dist.size() - 1);
                if (leftover.dist.get(leftover.dist.size() - 1) <= fmax) {
                    pointSet.add(leftover);
                } else {
                    far.add(leftover);
                }
            }
            for (DistanceSet consumed : newConsumedSet) {
                consumed.dist.remove(consumed.dist.size() - 1);
                consumedSet.add(consumed);
            }
            newPointSet.clear();
            newConsumedSet.clear();
        }

        pointSet.addAll(far);
        Node node = new Node(p);
        node.scale = topScale - maxScale;
        node.maxDist = this.max(consumedSet);
        node.children = children;
        return node;
    }

    /** Returns the cover radius {@code base^s} of scale {@code s}. */
    private double getCoverRadius(int s) {
        return Math.pow((double) this.base, (double) s);
    }

    /**
     * Returns the smallest integer scale {@code s} with {@code s >= log_base(d)}.
     * A zero distance maps to {@code Integer.MIN_VALUE} (the int cast of -infinity).
     */
    private int getScale(double d) {
        return (int) Math.ceil((double) (this.invLogBase * Math.log((double) d)));
    }

    /** Creates a leaf for data point {@code idx}. */
    private Node newLeaf(int idx) {
        return new Node(idx, 0.0, 0.0, null, 100);
    }

    /** Returns the largest top-of-stack distance in {@code v}, or 0 if {@code v} is empty. */
    private double max(ArrayList<DistanceSet> v) {
        double max = 0.0;
        for (DistanceSet n : v) {
            double d = n.dist.get(n.dist.size() - 1);
            if (max < d) {
                max = d;
            }
        }
        return max;
    }

    /**
     * Moves the points of {@code pointSet} whose current distance exceeds the
     * cover radius of {@code maxScale} into {@code farSet}.
     */
    private void split(ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> farSet, int maxScale) {
        double fmax = this.getCoverRadius(maxScale);
        ArrayList<DistanceSet> near = new ArrayList<DistanceSet>(pointSet.size());
        for (DistanceSet n : pointSet) {
            if (n.dist.get(n.dist.size() - 1) <= fmax) {
                near.add(n);
            } else {
                farSet.add(n);
            }
        }
        pointSet.clear();
        pointSet.addAll(near);
    }

    /**
     * Moves the points of {@code pointSet} within the cover radius of
     * {@code maxScale} of {@code newPoint} into {@code newPointSet}, pushing their
     * distance to {@code newPoint} onto their distance stacks.
     */
    private void distSplit(ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> newPointSet, E newPoint, int maxScale) {
        double fmax = this.getCoverRadius(maxScale);
        ArrayList<DistanceSet> remaining = new ArrayList<DistanceSet>(pointSet.size());
        for (DistanceSet n : pointSet) {
            double newDist = this.distance.d(newPoint, n.getObject());
            if (newDist <= fmax) {
                n.dist.add(newDist);
                newPointSet.add(n);
            } else {
                remaining.add(n);
            }
        }
        pointSet.clear();
        pointSet.addAll(remaining);
    }

    /** Allocates a {@code Neighbor<E, E>} array of length {@code n}. */
    @SuppressWarnings("unchecked")
    private Neighbor<E, E>[] newNeighborArray(int n) {
        return (Neighbor<E, E>[]) Array.newInstance(Neighbor.class, n);
    }

    /**
     * Returns the nearest neighbor of {@code q}.
     *
     * @param q the query.
     * @return the nearest neighbor, or {@code null} if the search found none.
     */
    @Override
    public Neighbor<E, E> nearest(E q) {
        Neighbor<E, E>[] result = this.knn(q, 1);
        return result.length == 0 ? null : result[0];
    }

    /**
     * Returns the k nearest neighbors of {@code q}.
     *
     * @param q the query.
     * @param k the number of neighbors.
     * @return up to {@code k} neighbors, ordered by {@link Arrays#sort} on
     *     {@link Neighbor} and then reversed. Fewer than {@code k} may be
     *     returned; a warning is logged in that case.
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k} exceeds the
     *     dataset size.
     */
    @Override
    public Neighbor<E, E>[] knn(E q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > this.data.length) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }

        E e = this.root.getObject();
        double d = this.distance.d(e, q);
        if (this.root.children == null) {
            // The root is a leaf (e.g. a one-element dataset).
            // NOTE(review): this point is returned even when identicalExcluded is set
            // and it is q itself; kept as-is for compatibility.
            Neighbor<E, E>[] single = this.newNeighborArray(1);
            single[0] = new Neighbor<E, E>(e, e, this.root.idx, d);
            return single;
        }

        ArrayList<DistanceNode> currentCoverSet = new ArrayList<DistanceNode>();
        ArrayList<DistanceNode> zeroSet = new ArrayList<DistanceNode>();
        currentCoverSet.add(new DistanceNode(d, this.root));

        // Tracks the k smallest distances seen. The sentinel presumably keeps
        // heap.peek() at MAX_VALUE until k real distances are present
        // (depends on DoubleHeapSelect internals).
        DoubleHeapSelect heap = new DoubleHeapSelect(k);
        heap.add(Double.MAX_VALUE);

        // True until the first real distance is in the heap; until then nothing is pruned.
        boolean emptyHeap = true;
        if (!this.identicalExcluded || e != q) {
            heap.add(d);
            emptyHeap = false;
        }

        while (!currentCoverSet.isEmpty()) {
            ArrayList<DistanceNode> nextCoverSet = new ArrayList<DistanceNode>();
            for (DistanceNode par : currentCoverSet) {
                Node parent = par.node;
                for (int c = 0; c < parent.children.size(); ++c) {
                    Node child = parent.children.get(c);
                    // Child 0 holds the parent's own point, so its distance is known.
                    d = c == 0 ? par.dist : this.distance.d(child.getObject(), q);
                    double upperBound = emptyHeap ? Double.POSITIVE_INFINITY : heap.peek();
                    // No point under child can be closer than d - child.maxDist.
                    if (!(d <= upperBound + child.maxDist)) {
                        continue;
                    }
                    // Child 0's distance was already counted at the parent's level.
                    if (c > 0 && d < upperBound && (!this.identicalExcluded || child.getObject() != q)) {
                        heap.add(d);
                        // Start pruning with the heap bound, as the non-excluded path does.
                        emptyHeap = false;
                    }
                    if (child.children != null) {
                        nextCoverSet.add(new DistanceNode(d, child));
                    } else if (d <= upperBound) {
                        zeroSet.add(new DistanceNode(d, child));
                    }
                }
            }
            currentCoverSet = nextCoverSet;
        }

        double upperBound = heap.peek();
        ArrayList<DistanceNode> hits = new ArrayList<DistanceNode>();
        for (DistanceNode ds : zeroSet) {
            if (ds.dist <= upperBound && (!this.identicalExcluded || ds.node.getObject() != q)) {
                hits.add(ds);
            }
        }
        if (hits.size() < k) {
            logger.warn("CoverTree.knn({}) returns only {} neighbors", k, hits.size());
        }
        if (hits.size() > k) {
            // Ties at the k-th smallest distance can leave more than k hits;
            // keep the k closest before the final ordering below.
            Collections.sort(hits);
            hits.subList(k, hits.size()).clear();
        }

        // Sized exactly, so the result never carries a trailing null slot.
        Neighbor<E, E>[] neighbors = this.newNeighborArray(hits.size());
        for (int i = 0; i < neighbors.length; ++i) {
            DistanceNode ds = hits.get(i);
            E obj = ds.node.getObject();
            neighbors[i] = new Neighbor<E, E>(obj, obj, ds.node.idx, ds.dist);
        }
        Arrays.sort(neighbors);
        Math.reverse((Object[]) neighbors);
        return neighbors;
    }

    /**
     * Appends to {@code neighbors} every data point within {@code radius} of
     * {@code q}.
     *
     * @param q the query.
     * @param radius the search radius; must be positive.
     * @param neighbors receives the points found.
     * @throws IllegalArgumentException if {@code radius <= 0}.
     */
    @Override
    public void range(E q, double radius, List<Neighbor<E, E>> neighbors) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        E rootObject = this.root.getObject();
        double d = this.distance.d(rootObject, q);
        if (this.root.children == null) {
            // The root is a leaf (e.g. a one-element dataset); the traversal below
            // would dereference its null child list.
            if (d <= radius && (!this.identicalExcluded || rootObject != q)) {
                neighbors.add(new Neighbor<E, E>(rootObject, rootObject, this.root.idx, d));
            }
            return;
        }

        ArrayList<DistanceNode> currentCoverSet = new ArrayList<DistanceNode>();
        ArrayList<DistanceNode> zeroSet = new ArrayList<DistanceNode>();
        currentCoverSet.add(new DistanceNode(d, this.root));
        while (!currentCoverSet.isEmpty()) {
            ArrayList<DistanceNode> nextCoverSet = new ArrayList<DistanceNode>();
            for (DistanceNode par : currentCoverSet) {
                Node parent = par.node;
                for (int c = 0; c < parent.children.size(); ++c) {
                    Node child = parent.children.get(c);
                    // Child 0 holds the parent's own point, so its distance is known.
                    d = c == 0 ? par.dist : this.distance.d(child.getObject(), q);
                    // No point under child can be closer than d - child.maxDist.
                    if (!(d <= radius + child.maxDist)) {
                        continue;
                    }
                    if (child.children != null) {
                        nextCoverSet.add(new DistanceNode(d, child));
                    } else if (d <= radius) {
                        zeroSet.add(new DistanceNode(d, child));
                    }
                }
            }
            currentCoverSet = nextCoverSet;
        }

        for (DistanceNode ds : zeroSet) {
            E obj = ds.node.getObject();
            if (this.identicalExcluded && obj == q) {
                continue;
            }
            neighbors.add(new Neighbor<E, E>(obj, obj, ds.node.idx, ds.dist));
        }
    }

    /** A tree node paired with its distance to the current query. */
    class DistanceNode
    implements Comparable<DistanceNode> {
        double dist;
        Node node;

        DistanceNode(double dist, Node node) {
            this.dist = dist;
            this.node = node;
        }

        /** Orders by {@link #dist}, ascending. */
        @Override
        public int compareTo(DistanceNode o) {
            return (int) Math.signum((double) (this.dist - o.dist));
        }
    }

    /**
     * A data point being inserted during construction, with a stack of its
     * distances to the points on the current insertion path (last = nearest
     * enclosing point).
     */
    class DistanceSet {
        int idx;
        ArrayList<Double> dist;

        DistanceSet() {
            this.dist = new ArrayList<Double>();
        }

        DistanceSet(int idx) {
            this.idx = idx;
            this.dist = new ArrayList<Double>();
        }

        E getObject() {
            return CoverTree.this.data[this.idx];
        }
    }

    /**
     * A cover tree node. {@code children == null} marks a leaf. {@code scale} is
     * not read by the queries in this class.
     */
    class Node
    implements Serializable {
        int idx;
        double maxDist;
        double parentDist;
        ArrayList<Node> children;
        int scale;

        Node(int idx) {
            this.idx = idx;
        }

        Node(int idx, double maxDist, double parentDist, ArrayList<Node> children, int scale) {
            this.idx = idx;
            this.maxDist = maxDist;
            this.parentDist = parentDist;
            this.children = children;
            this.scale = scale;
        }

        E getObject() {
            return CoverTree.this.data[this.idx];
        }

        boolean isLeaf() {
            return this.children == null;
        }
    }
}

