/*
 * Decompiled with CFR 0.152.
 */
package org.tinspin.index.kdtree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import org.tinspin.index.Index;
import org.tinspin.index.IndexConfig;
import org.tinspin.index.PointDistance;
import org.tinspin.index.PointMap;
import org.tinspin.index.PointMultimap;
import org.tinspin.index.Stats;
import org.tinspin.index.kdtree.KDIterator;
import org.tinspin.index.kdtree.KDIteratorKnn;
import org.tinspin.index.kdtree.Node;
import org.tinspin.index.util.MutableRef;
import org.tinspin.index.util.StringBuilderLn;

/**
 * A kd-tree that stores points of a fixed dimensionality together with values.
 * It implements both {@link PointMap} and {@link PointMultimap}, so duplicate keys are allowed.
 *
 * <p>Points whose coordinate in a node's split dimension is {@code >=} the node's
 * coordinate are stored in the node's hi subtree; smaller ones go to the lo subtree.
 * Removal can break this invariant, because a removed inner node may be replaced by the
 * maximum of its lo subtree. After any removal, exact-match lookups therefore switch to a
 * slower search that also inspects lo subtrees on equal split coordinates.
 *
 * @param <T> value type
 */
public class KDTree<T>
implements PointMap<T>,
PointMultimap<T> {
    public static final boolean DEBUG = false;
    /** kNN queries on trees smaller than this use the eager, list-based search. */
    private static final int SMALL_KNN_MAX_SIZE = 1000000;
    /** kNN queries with k up to this value use the eager, list-based search. */
    private static final int SMALL_KNN_MAX_K = 10;

    private final int dims;
    private final boolean defensiveKeyCopy;
    private int size = 0;
    private int modCount = 0;
    private long nDist1NN = 0L;
    private long nDistKNN = 0L;
    /** Set after a removal; forces the slow (invariant-tolerant) exact-match search. */
    private boolean invariantBroken = false;
    private Node<T> root;
    /** Orders kNN candidates by ascending distance. */
    private static final Comparator<Index.PointEntryKnn<?>> compKnn = (point1, point2) -> {
        double deltaDist = point1.dist() - point2.dist();
        return deltaDist < 0.0 ? -1 : (deltaDist > 0.0 ? 1 : 0);
    };

    private KDTree(int dims, boolean defensiveKeyCopy) {
        this.dims = dims;
        this.defensiveKeyCopy = defensiveKeyCopy;
    }

    /**
     * Creates an empty tree that copies inserted keys defensively.
     *
     * @param dims number of dimensions
     * @return a new empty tree
     */
    public static <T> KDTree<T> create(int dims) {
        return new KDTree<T>(dims, true);
    }

    /**
     * Creates an empty tree from a configuration.
     *
     * @param config provides the dimensionality and the defensive-key-copy flag
     * @return a new empty tree
     */
    public static <T> KDTree<T> create(IndexConfig config) {
        return new KDTree<T>(config.getDimensions(), config.getDefensiveKeyCopy());
    }

    /**
     * Inserts a key/value pair. Duplicate keys are allowed.
     *
     * @param key   the point
     * @param value the value (may be {@code null})
     */
    @Override
    public void insert(double[] key, T value) {
        ++this.size;
        ++this.modCount;
        if (this.root == null) {
            this.root = new Node<T>(key, value, 0, this.defensiveKeyCopy);
            return;
        }
        Node<T> n = this.root;
        // Descend until getClosestNodeOrAddPoint reports that it attached the point (null).
        while ((n = n.getClosestNodeOrAddPoint(key, value, this.dims, this.defensiveKeyCopy)) != null) {
        }
    }

    @Override
    public boolean contains(double[] key) {
        return this.findNodeExact(key, new RemoveResult<>(), e -> true) != null;
    }

    @Override
    public KDIterator<T> queryExactPoint(double[] point) {
        return this.query(point, point);
    }

    /**
     * Returns the value of one entry whose key equals {@code key}.
     *
     * @return the value, or {@code null} if no such entry exists (or its value is {@code null})
     */
    @Override
    public T queryExact(double[] key) {
        Node<T> e = this.findNodeExact(key, new RemoveResult<>(), entry -> true);
        return e == null ? null : e.value();
    }

    /**
     * Finds a node whose key equals {@code key} and that satisfies {@code filter}.
     * The split dimension of the found node and its parent are recorded in {@code resultDepth}.
     *
     * @return the matching node or {@code null}
     */
    private Node<T> findNodeExact(double[] key, RemoveResult<T> resultDepth, Predicate<Index.PointEntry<T>> filter) {
        if (this.root == null) {
            return null;
        }
        return this.invariantBroken
                ? this.findNodeExactSlow(key, this.root, null, resultDepth, filter)
                : this.findNodeExactFast(key, null, resultDepth, filter);
    }

    /** Single-path descent from the root. Only valid while the insert invariant holds. */
    private Node<T> findNodeExactFast(double[] key, Node<T> parent, RemoveResult<T> resultDepth, Predicate<Index.PointEntry<T>> filter) {
        double nodeX;
        double keyX;
        Node<T> n = this.root;
        do {
            double[] nodeKey = n.point();
            nodeX = nodeKey[n.getDim()];
            keyX = key[n.getDim()];
            if (keyX == nodeX && Arrays.equals(key, nodeKey) && filter.test(n)) {
                resultDepth.pos = n.getDim();
                resultDepth.nodeParent = parent;
                return n;
            }
            parent = n;
        } while ((n = keyX >= nodeX ? n.getHi() : n.getLo()) != null);
        return null;
    }

    /**
     * Descent that also searches the lo subtree whenever the split coordinate is equal.
     * After removals, equal coordinates may appear on the lo side.
     */
    private Node<T> findNodeExactSlow(double[] key, Node<T> n, Node<T> parent, RemoveResult<T> resultDepth, Predicate<Index.PointEntry<T>> filter) {
        double nodeX;
        double keyX;
        do {
            double[] nodeKey = n.point();
            nodeX = nodeKey[n.getDim()];
            keyX = key[n.getDim()];
            if (keyX == nodeX) {
                Node<T> n2;
                if (Arrays.equals(key, nodeKey) && filter.test(n)) {
                    resultDepth.pos = n.getDim();
                    resultDepth.nodeParent = parent;
                    return n;
                }
                if (n.getLo() != null && (n2 = this.findNodeExactSlow(key, n.getLo(), n, resultDepth, filter)) != null) {
                    return n2;
                }
            }
            parent = n;
        } while ((n = keyX >= nodeX ? n.getHi() : n.getLo()) != null);
        return null;
    }

    @Override
    public boolean remove(double[] key, T value) {
        return this.removeIf(key, e -> Objects.equals(e.value(), value));
    }

    /**
     * Removes one entry with the given key.
     *
     * @return the removed value, or {@code null} if nothing was removed
     */
    @Override
    public T remove(double[] key) {
        MutableRef<T> ref = new MutableRef<>();
        this.removeIf(key, e -> {
            ref.set(e.value());
            return true;
        });
        return ref.get();
    }

    /**
     * Removes the first entry with key {@code key} that satisfies {@code pred}.
     *
     * @return {@code true} if an entry was removed
     */
    @Override
    public boolean removeIf(double[] key, Predicate<Index.PointEntry<T>> pred) {
        if (this.root == null) {
            return false;
        }
        this.invariantBroken = true;
        RemoveResult<T> removeResult = new RemoveResult<>();
        Node<T> eToRemove = this.findNodeExact(key, removeResult, pred);
        if (eToRemove == null) {
            return false;
        }
        ++this.modCount;
        if (eToRemove == this.root && this.size == 1) {
            this.root = null;
            this.size = 0;
            this.invariantBroken = false;
            return true;
        }
        // Repeatedly replace the inner node with a suitable descendant (min of hi subtree,
        // else max of lo subtree) until the node to physically unlink is a leaf.
        while (eToRemove != null && !eToRemove.isLeaf()) {
            int pos = removeResult.pos;
            removeResult.node = null;
            if (eToRemove.getHi() != null) {
                removeResult.best = Double.POSITIVE_INFINITY;
                this.removeMinLeaf(eToRemove.getHi(), eToRemove, pos, removeResult);
            } else if (eToRemove.getLo() != null) {
                removeResult.best = Double.NEGATIVE_INFINITY;
                this.removeMaxLeaf(eToRemove.getLo(), eToRemove, pos, removeResult);
            }
            eToRemove.set(removeResult.node.point(), removeResult.node.value());
            eToRemove = removeResult.node;
        }
        Node<T> parent = removeResult.nodeParent;
        if (parent != null) {
            if (parent.getLo() == eToRemove) {
                parent.setLeft(null);
            } else if (parent.getHi() == eToRemove) {
                parent.setRight(null);
            } else {
                throw new IllegalStateException();
            }
        }
        --this.size;
        return true;
    }

    /** Finds the node with the smallest coordinate in dimension {@code pos} below {@code node}. */
    private void removeMinLeaf(Node<T> node, Node<T> parent, int pos, RemoveResult<T> result) {
        if (pos == node.getDim()) {
            // Same split dimension: smaller values can only be on the lo side.
            if (node.getLo() != null) {
                this.removeMinLeaf(node.getLo(), node, pos, result);
            } else if (node.point()[pos] <= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = node.point()[pos];
                result.pos = node.getDim();
            }
        } else {
            double localX = node.point()[pos];
            if (localX <= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = localX;
                result.pos = node.getDim();
            }
            if (node.getLo() != null) {
                this.removeMinLeaf(node.getLo(), node, pos, result);
            }
            if (node.getHi() != null) {
                this.removeMinLeaf(node.getHi(), node, pos, result);
            }
        }
    }

    /** Finds the node with the largest coordinate in dimension {@code pos} below {@code node}. */
    private void removeMaxLeaf(Node<T> node, Node<T> parent, int pos, RemoveResult<T> result) {
        if (pos == node.getDim()) {
            // Same split dimension: larger values can only be on the hi side.
            if (node.getHi() != null) {
                this.removeMaxLeaf(node.getHi(), node, pos, result);
            } else if (node.point()[pos] >= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = node.point()[pos];
                result.pos = node.getDim();
            }
        } else {
            double localX = node.point()[pos];
            if (localX >= result.best) {
                result.node = node;
                result.nodeParent = parent;
                result.best = localX;
                result.pos = node.getDim();
            }
            if (node.getLo() != null) {
                this.removeMaxLeaf(node.getLo(), node, pos, result);
            }
            if (node.getHi() != null) {
                this.removeMaxLeaf(node.getHi(), node, pos, result);
            }
        }
    }

    /**
     * Moves one entry from {@code oldKey} to {@code newKey}.
     * Entries with a {@code null} value are moved as well; they are not dropped.
     *
     * @return the moved value, or {@code null} if nothing was moved or the value was {@code null}
     */
    @Override
    public T update(double[] oldKey, double[] newKey) {
        if (this.root == null) {
            return null;
        }
        MutableRef<T> ref = new MutableRef<>();
        boolean removed = this.removeIf(oldKey, e -> {
            ref.set(e.value());
            return true;
        });
        if (removed) {
            this.insert(newKey, ref.get());
            return ref.get();
        }
        return null;
    }

    @Override
    public boolean update(double[] oldKey, double[] newKey, T value) {
        if (this.root == null) {
            return false;
        }
        if (this.remove(oldKey, value)) {
            this.insert(newKey, value);
            return true;
        }
        return false;
    }

    @Override
    public boolean contains(double[] key, T value) {
        return this.findNodeExact(key, new RemoveResult<>(), e -> Objects.equals(value, e.value())) != null;
    }

    @Override
    public int size() {
        return this.size;
    }

    @Override
    public void clear() {
        this.size = 0;
        this.root = null;
        this.invariantBroken = false;
        ++this.modCount;
    }

    /** Returns an iterator over all entries inside the axis-aligned box [{@code min2}, {@code max}]. */
    @Override
    public KDIterator<T> query(double[] min2, double[] max) {
        return new KDIterator<>(this, min2, max);
    }

    /** Tests whether {@code point} lies inside the closed box [{@code min2}, {@code max}]. */
    static boolean isEnclosed(double[] point, double[] min2, double[] max) {
        for (int i = 0; i < point.length; ++i) {
            if (point[i] < min2[i] || point[i] > max[i]) {
                return false;
            }
        }
        return true;
    }

    /** Eagerly computes the {@code k} nearest neighbors, sorted by ascending distance. */
    private List<Index.PointEntryKnn<T>> knnQuery(double[] center, int k, PointDistance distFn) {
        if (this.root == null) {
            return Collections.emptyList();
        }
        ArrayList<Index.PointEntryKnn<T>> candidates = new ArrayList<Index.PointEntryKnn<T>>(k);
        this.rangeSearchKNN(this.root, center, candidates, k, Double.POSITIVE_INFINITY, distFn);
        return candidates;
    }

    /**
     * Recursive kNN search. The subtree on the center's side is visited first; the other side
     * is visited only if the split plane is within {@code maxRange}.
     *
     * @return the updated search radius
     */
    private double rangeSearchKNN(Node<T> node, double[] center, ArrayList<Index.PointEntryKnn<T>> candidates, int k, double maxRange, PointDistance distFn) {
        int pos = node.getDim();
        if (node.getLo() != null && (center[pos] < node.point()[pos] || node.getHi() == null)) {
            maxRange = this.rangeSearchKNN(node.getLo(), center, candidates, k, maxRange, distFn);
            if (center[pos] + maxRange >= node.point()[pos]) {
                maxRange = this.addCandidate(node, center, candidates, k, maxRange, distFn);
                if (node.getHi() != null) {
                    maxRange = this.rangeSearchKNN(node.getHi(), center, candidates, k, maxRange, distFn);
                }
            }
        } else if (node.getHi() != null) {
            maxRange = this.rangeSearchKNN(node.getHi(), center, candidates, k, maxRange, distFn);
            if (center[pos] <= node.point()[pos] + maxRange) {
                maxRange = this.addCandidate(node, center, candidates, k, maxRange, distFn);
                if (node.getLo() != null) {
                    maxRange = this.rangeSearchKNN(node.getLo(), center, candidates, k, maxRange, distFn);
                }
            }
        } else {
            maxRange = this.addCandidate(node, center, candidates, k, maxRange, distFn);
        }
        return maxRange;
    }

    /**
     * Inserts {@code node} into the sorted candidate list if it is close enough.
     * Once the list is full, the farthest candidate object is reused.
     *
     * @return the new search radius: the k-th distance once k candidates exist, else {@code maxRange}
     */
    private double addCandidate(Node<T> node, double[] center, ArrayList<Index.PointEntryKnn<T>> candidates, int k, double maxRange, PointDistance distFn) {
        Index.PointEntryKnn<T> cand;
        ++this.nDistKNN;
        double dist = distFn.dist(center, node.point());
        if (dist > maxRange) {
            return maxRange;
        }
        if (dist == maxRange && candidates.size() >= k) {
            return maxRange;
        }
        if (candidates.size() >= k) {
            cand = candidates.remove(k - 1);
            cand.set(node, dist);
        } else {
            cand = new Index.PointEntryKnn<T>(node, dist);
        }
        int insertionPos = Collections.binarySearch(candidates, cand, compKnn);
        insertionPos = insertionPos >= 0 ? insertionPos : -(insertionPos + 1);
        candidates.add(insertionPos, cand);
        return candidates.size() < k ? maxRange : candidates.get(candidates.size() - 1).dist();
    }

    @Override
    public String toStringTree() {
        StringBuilderLn sb = new StringBuilderLn();
        if (this.root == null) {
            sb.append("empty tree");
        } else {
            this.toStringTree(sb, this.root, 0);
        }
        return sb.toString();
    }

    /** In-order dump; each line is prefixed with one '.' per depth level. */
    private void toStringTree(StringBuilderLn sb, Node<T> node, int depth) {
        if (node.getLo() != null) {
            this.toStringTree(sb, node.getLo(), depth + 1);
        }
        for (int i = 0; i < depth; ++i) {
            sb.append(".");
        }
        sb.append(" ");
        sb.append(Arrays.toString(node.point()));
        sb.append(" v=").append(node.value());
        sb.append(" l/r=");
        sb.append(node.getLo() == null ? null : Arrays.toString(node.getLo().point()));
        sb.append("/");
        sb.append(node.getHi() == null ? null : Arrays.toString(node.getHi().point()));
        sb.appendLn();
        if (node.getHi() != null) {
            this.toStringTree(sb, node.getHi(), depth + 1);
        }
    }

    @Override
    public String toString() {
        return "KDTree;size=" + this.size + ";DEBUG=" + DEBUG + ";center=" + (this.root == null ? "null" : Arrays.toString(this.root.point()));
    }

    @Override
    public KDStats getStats() {
        KDStats s2 = new KDStats(this);
        if (this.root != null) {
            this.root.checkNode(s2, 0);
        }
        return s2;
    }

    @Override
    public int getDims() {
        return this.dims;
    }

    /**
     * Returns an iterator over all entries. It is implemented as a range query over an
     * unbounded box, which {@link #isEnclosed} accepts for every point.
     */
    @Override
    public Index.PointIterator<T> iterator() {
        double[] min = new double[this.dims];
        double[] max = new double[this.dims];
        if (this.root != null) {
            Arrays.fill(min, Double.NEGATIVE_INFINITY);
            Arrays.fill(max, Double.POSITIVE_INFINITY);
        }
        return this.query(min, max);
    }

    /**
     * Finds the nearest neighbor using Euclidean (L2) distance.
     *
     * @return the nearest entry, or {@code null} if the tree is empty
     */
    @Override
    public Index.PointEntryKnn<T> query1nn(double[] center) {
        Index.PointIteratorKnn<T> it = this.queryKnn(center, 1);
        return it.hasNext() ? it.next() : null;
    }

    /** kNN query using Euclidean (L2) distance; see {@link #queryKnn(double[], int, PointDistance)}. */
    @Override
    public Index.PointIteratorKnn<T> queryKnn(double[] center, int k) {
        return this.queryKnn(center, k, PointDistance.L2);
    }

    /**
     * kNN query with a custom distance function. Empty trees, and small trees with a small
     * {@code k}, use the eager list-based search. All other queries use {@link KDIteratorKnn}.
     */
    @Override
    public Index.PointIteratorKnn<T> queryKnn(double[] center, int k, PointDistance distFn) {
        // The list-based search handles an empty tree; KDIteratorKnn would receive a null root.
        if (this.root == null || (this.size < SMALL_KNN_MAX_SIZE && k <= SMALL_KNN_MAX_K)) {
            return new KDQueryIteratorKnn<>(this, center, k, distFn);
        }
        return new KDIteratorKnn<T>(this.root, k, center, distFn, (e, d) -> true);
    }

    @Override
    public int getNodeCount() {
        return this.getStats().getNodeCount();
    }

    @Override
    public int getDepth() {
        return this.getStats().getMaxDepth();
    }

    Node<T> getRoot() {
        return this.root;
    }

    /** Statistics carrying the tree's distance-calculation counters. */
    public static class KDStats
    extends Stats {
        public KDStats(KDTree<?> tree) {
            super(tree.nDist1NN + tree.nDistKNN, tree.nDist1NN, tree.nDistKNN);
        }
    }

    /** kNN iterator backed by an eagerly computed, sorted result list. */
    private static class KDQueryIteratorKnn<T>
    implements Index.PointIteratorKnn<T> {
        private Iterator<Index.PointEntryKnn<T>> it;
        private final KDTree<T> tree;
        private final PointDistance distFn;

        public KDQueryIteratorKnn(KDTree<T> tree, double[] center, int k, PointDistance distFn) {
            this.tree = tree;
            this.distFn = distFn;
            this.reset(center, k);
        }

        @Override
        public boolean hasNext() {
            return this.it.hasNext();
        }

        @Override
        public Index.PointEntryKnn<T> next() {
            return this.it.next();
        }

        @Override
        public KDQueryIteratorKnn<T> reset(double[] center, int k) {
            this.it = this.tree.knnQuery(center, k, this.distFn).iterator();
            return this;
        }
    }

    /** Mutable out-parameter used by exact lookup and removal. */
    private static class RemoveResult<T> {
        /** Replacement candidate found by removeMinLeaf/removeMaxLeaf. */
        Node<T> node = null;
        /** Parent of the found / candidate node ({@code null} for the root). */
        Node<T> nodeParent = null;
        /** Best coordinate seen so far during the min/max search. */
        double best;
        /** Split dimension of the found / candidate node. */
        int pos;

        private RemoveResult() {
        }
    }
}

