User:Jdev/Code/VP Tree
From RoboWiki
My implementation of a vantage-point tree. It is not well tested and may still contain debug code, but in general it should work. It is published just for history; I recommend using User:Rednaxela/kD-Tree for kNN search, because it is faster and more reliable. But if this code is interesting to you, feel free to use and adapt it.
package lxx.utils.vp_tree;

import lxx.utils.IntervalDouble;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import static java.lang.Math.*;

/**
 * Vantage-point tree over fixed-dimension {@code double[]} points, using Euclidean distance.
 *
 * <p>A node is either a leaf, holding its points in a flat buffer, or an internal node, holding a
 * vantage point, a split radius and two children: points closer to the vantage point than the
 * split radius live in the inner child, all others in the outer child. A leaf is split once it
 * holds at least {@code LOCATIONS_LIMIT} points that are not all identical.
 *
 * <p>The class does no synchronization. Searches write the computed distance into
 * {@link VPTreeLocation#lastDist} of the stored locations, so results of an earlier query may be
 * overwritten by a later one.
 */
public class VPTree {

    /** Number of nodes constructed so far, across all trees (statistic only). */
    public static int nodesCount = 0;
    /** Largest {@code level} value passed to any node constructor so far (statistic only). */
    public static int maxLevel = 0;

    /** Initial leaf capacity and the point count at which a non-singular leaf is split. */
    private static final int LOCATIONS_LIMIT = 32;

    // Observed range of vantage-point distances of the points sent to each child.
    // NOTE(review): relies on IntervalDouble.extend(x) widening [a, b] to include x - confirm.
    private final IntervalDouble innerBounds = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);
    private final IntervalDouble outerBounds = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);

    private final int level;
    private final int dimensionsCount;
    /** Per-dimension coordinate range of the points stored in this leaf. */
    private final IntervalDouble[] dimensionsWidth;

    /** Leaf storage; {@code null} once the node has been split. */
    private VPTreeLocation[] locations;
    private int size = 0;

    private VPTree innerSubTree;
    private VPTree outerSubTree;
    /** {@code null} while this node is a leaf. */
    private double[] vantagePoint;
    /** True while every point stored in this leaf has the same coordinates. */
    private boolean singular = true;
    private double splitRadius;

    /**
     * Creates an empty node.
     *
     * @param level           depth of this node, used only for the {@link #maxLevel} statistic
     *                        and passed on (plus one) to children
     * @param dimensionsCount number of coordinates read from every point
     * @param locationsCount  kept for source compatibility and otherwise ignored: the node always
     *                        starts empty, because pre-setting the size would leave null slots
     *                        that every other method dereferences
     */
    public VPTree(int level, int dimensionsCount, int locationsCount) {
        this.level = level;
        this.dimensionsCount = dimensionsCount;
        this.locations = new VPTreeLocation[LOCATIONS_LIMIT];
        this.dimensionsWidth = new IntervalDouble[dimensionsCount];
        for (int i = 0; i < dimensionsCount; i++) {
            dimensionsWidth[i] = new IntervalDouble(Integer.MAX_VALUE, Integer.MIN_VALUE);
        }
        nodesCount++;
        maxLevel = max(maxLevel, level);
    }

    /**
     * Creates an empty node.
     *
     * @param level           depth of this node
     * @param dimensionsCount number of coordinates read from every point
     */
    public VPTree(int level, int dimensionsCount) {
        this(level, dimensionsCount, 0);
    }

    /**
     * Inserts a point into this subtree.
     *
     * @param location point to store; its {@code location} array must have at least
     *                 {@code dimensionsCount} elements
     */
    public void add(VPTreeLocation location) {
        if (vantagePoint == null) {
            addImpl(location);
            return;
        }
        location.distToLastVantagePoint = getDistance(vantagePoint, location.location);
        addToChild(location);
    }

    /**
     * Finds up to {@code k} stored points closest to {@code location}.
     *
     * @param location query point
     * @param k        maximum number of neighbours to return
     * @return array of length {@code k}, ordered by ascending distance; trailing elements are
     *         {@code null} when the tree holds fewer than {@code k} points
     */
    public VPTreeLocation[] findNearestNeighbours(double[] location, int k) {
        final NeighborsSet set = new NeighborsSet(location, new VPTreeLocation[k]);
        findNearestNeighbours(set);
        return set.neighbors;
    }

    /** Recursive search step: offers the points of this subtree to {@code set}. */
    private void findNearestNeighbours(NeighborsSet set) {
        if (vantagePoint == null) {
            searchLeaf(set);
            return;
        }

        final double distToVP = getDistance(vantagePoint, set.center.location);
        if (distToVP + set.getMaxDistance() < outerBounds.a && set.isFilled()) {
            // The current result ball cannot reach the closest outer point.
            innerSubTree.findNearestNeighbours(set);
        } else if (distToVP - set.getMaxDistance() > innerBounds.b && set.isFilled()) {
            // The current result ball cannot reach the farthest inner point.
            outerSubTree.findNearestNeighbours(set);
        } else if (distToVP < (innerBounds.a + outerBounds.b) / 2) {
            // Visit the inner child first, then the outer one unless it got pruned meanwhile.
            innerSubTree.findNearestNeighbours(set);
            if (distToVP + set.getMaxDistance() >= outerBounds.a || !set.isFilled()) {
                outerSubTree.findNearestNeighbours(set);
            }
        } else {
            outerSubTree.findNearestNeighbours(set);
            if (distToVP - set.getMaxDistance() <= innerBounds.b || !set.isFilled()) {
                innerSubTree.findNearestNeighbours(set);
            }
        }
    }

    /** Offers every point of this leaf to {@code set}. */
    private void searchLeaf(NeighborsSet set) {
        set.visitedNodes++;
        if (size == 0) {
            // An empty leaf (e.g. a freshly created tree) has nothing to offer.
            return;
        }
        if (singular) {
            // All points coincide, so one distance computation serves the whole leaf
            // and the points can be inserted as a single run.
            final double dist = getDistance(set.center.location, locations[0].location);
            for (int i = 0; i < size; i++) {
                locations[i].lastDist = dist;
            }
            set.add(size, locations);
        } else {
            // todo: try find out index of center in locations and go in both directions from it
            for (int i = 0; i < size; i++) {
                locations[i].lastDist = getDistance(set.center.location, locations[i].location);
                set.add(locations[i]);
            }
        }
    }

    /**
     * Collects every point stored in this subtree.
     *
     * @return the stored points; for a leaf this is a fixed-size list over a copy of its buffer
     */
    public List<VPTreeLocation> getAll() {
        if (vantagePoint == null) {
            return Arrays.asList(Arrays.copyOf(locations, size));
        }
        final List<VPTreeLocation> all = new ArrayList<VPTreeLocation>();
        all.addAll(innerSubTree.getAll());
        all.addAll(outerSubTree.getAll());
        return all;
    }

    /** Turns this leaf into an internal node and moves its points into two new children. */
    private void split() {
        innerSubTree = new VPTree(level + 1, dimensionsCount);
        outerSubTree = new VPTree(level + 1, dimensionsCount);
        splitRadius = getSplitRadius();

        // Internal nodes never read the leaf buffer again, so release it.
        final VPTreeLocation[] pending = locations;
        final int pendingCount = size;
        locations = null;
        size = 0;
        for (int i = 0; i < pendingCount; i++) {
            addToChild(pending[i]);
        }
    }

    /** Appends a point to this leaf and splits the leaf when it is full and not singular. */
    private void addImpl(VPTreeLocation location) {
        if (size == locations.length) {
            locations = Arrays.copyOf(locations, size * 2);
        }
        locations[size++] = location;

        for (int i = 0; i < dimensionsCount; i++) {
            dimensionsWidth[i].extend(location.location[i]);
            singular &= dimensionsWidth[i].a == dimensionsWidth[i].b;
        }

        // ">=" rather than "==": a leaf of identical points may have grown past the limit,
        // and must still split as soon as the first distinct point arrives.
        if (!singular && size >= LOCATIONS_LIMIT) {
            split();
        }
    }

    /** Routes a point, whose vantage-point distance is already set, to the matching child. */
    private void addToChild(VPTreeLocation location) {
        if (location.distToLastVantagePoint < splitRadius) {
            innerBounds.extend(location.distToLastVantagePoint);
            innerSubTree.add(location);
        } else {
            outerBounds.extend(location.distToLastVantagePoint);
            outerSubTree.add(location);
        }
    }

    /**
     * Picks the vantage point and computes the radius separating inner from outer points.
     *
     * <p>The vantage point is a copy of one (randomly chosen) end of the farthest pair of stored
     * points. The leaf buffer is then sorted by distance to it and the radius is placed in the
     * middle of the widest gap between neighbouring distances, ignoring the first and the last
     * gap. Side effects: sets {@code vantagePoint}, updates {@code distToLastVantagePoint} of
     * every stored point and reorders the leaf buffer.
     *
     * @return split radius; at least one stored point is strictly closer to the vantage point
     *         than this radius whenever the stored points are not all at distance zero
     */
    private double getSplitRadius() {
        double farthest = Double.NEGATIVE_INFINITY;
        final VPTreeLocation[] farthestPair = new VPTreeLocation[2];
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                final double dist = getDistance(locations[i].location, locations[j].location);
                if (dist > farthest) {
                    farthest = dist;
                    farthestPair[0] = locations[i];
                    farthestPair[1] = locations[j];
                }
            }
        }

        final VPTreeLocation chosen = farthestPair[(int) (2 * random())];
        vantagePoint = Arrays.copyOf(chosen.location, chosen.location.length);
        for (int i = 0; i < size; i++) {
            locations[i].distToLastVantagePoint = getDistance(vantagePoint, locations[i].location);
        }

        // Sort only the occupied prefix: the buffer may be longer than size and hold nulls.
        Arrays.sort(locations, 0, size, new Comparator<VPTreeLocation>() {
            @Override
            public int compare(VPTreeLocation o1, VPTreeLocation o2) {
                return Double.compare(o1.distToLastVantagePoint, o2.distToLastVantagePoint);
            }
        });

        double widestGap = Double.NEGATIVE_INFINITY;
        double radius = 0;
        for (int i = 1; i < size - 2; i++) {
            final double lower = locations[i].distToLastVantagePoint;
            final double upper = locations[i + 1].distToLastVantagePoint;
            final double gap = abs(lower - upper);
            if (gap > widestGap) {
                widestGap = gap;
                radius = (lower + upper) / 2;
            }
        }

        // If no point would fall inside the radius (all scanned distances are equal to the
        // smallest one), every point would land in the outer child, leaving the inner child
        // empty and making the outer child split again at once. Cut below the farthest point.
        if (!(locations[0].distToLastVantagePoint < radius)) {
            radius = locations[size - 1].distToLastVantagePoint;
        }
        return radius;
    }

    /** Euclidean distance between two points over the first {@code dimensionsCount} coordinates. */
    private double getDistance(double[] pnt1, double[] pnt2) {
        double sumOfSquares = 0;
        for (int i = 0; i < dimensionsCount; i++) {
            final double d = pnt1[i] - pnt2[i];
            sumOfSquares += d * d;
        }
        return sqrt(sumOfSquares);
    }

    /** A stored point together with scratch fields used while building and searching the tree. */
    public static class VPTreeLocation {

        /** Distance to the vantage point of the internal node that most recently routed it. */
        private double distToLastVantagePoint;
        /** Coordinates of the point; the array is stored as given, not copied. */
        public final double[] location;
        /** Distance to the query point, written by the most recent search that visited it. */
        public double lastDist;

        public VPTreeLocation(double[] location) {
            this.location = location;
        }
    }

    /** Bounded result list of a search, kept ordered by ascending {@code lastDist}. */
    public static class NeighborsSet {

        private final VPTreeLocation center;
        private final VPTreeLocation[] neighbors;
        // Search statistics; written here but not read anywhere in this class.
        private int visitedNodes = 0;
        private int locsCount = 0;
        private int size = 0;

        /**
         * @param center    query point
         * @param neighbors result buffer; its length is the number of neighbours wanted
         */
        public NeighborsSet(double[] center, VPTreeLocation[] neighbors) {
            this.center = new VPTreeLocation(center);
            this.neighbors = neighbors;
        }

        /** Offers all given entries as one run; see {@link #add(int, VPTreeLocation...)}. */
        public void add(VPTreeLocation... entries) {
            add(entries.length, entries);
        }

        /**
         * Offers the first {@code count} entries as one run. The run is positioned using the
         * distance of {@code entries[0]} only, so all offered entries must share the same
         * {@code lastDist}.
         *
         * @param count   number of leading entries to offer
         * @param entries candidates whose {@code lastDist} is already set
         */
        public void add(int count, VPTreeLocation... entries) {
            locsCount++;
            if (count <= 0 || neighbors.length == 0) {
                // Nothing offered, or no room for any result (k == 0).
                return;
            }
            if (size == neighbors.length && entries[0].lastDist > neighbors[size - 1].lastDist) {
                return;
            }

            int idx = 0;
            if (size > 0) {
                idx = findPosition(entries[0]);
                final int destPos = idx + count;
                if (destPos < neighbors.length) {
                    // Shift the tail right; entries pushed past the end are dropped.
                    System.arraycopy(neighbors, idx, neighbors, destPos, neighbors.length - destPos);
                }
            }
            System.arraycopy(entries, 0, neighbors, idx, min(count, neighbors.length - idx));
            size = min(size + count, neighbors.length);
        }

        /** Index just after the last kept neighbour that is strictly closer than {@code entry}. */
        private int findPosition(VPTreeLocation entry) {
            int idx = size - 1;
            while (idx >= 0 && entry.lastDist <= neighbors[idx].lastDist) {
                idx--;
            }
            return idx + 1;
        }

        /** Distance of the farthest kept neighbour, or {@code Integer.MAX_VALUE} when empty. */
        public double getMaxDistance() {
            return size > 0 ? neighbors[size - 1].lastDist : Integer.MAX_VALUE;
        }

        /** Whether the result buffer is full. */
        public boolean isFilled() {
            return size == neighbors.length;
        }
    }
}