/*
*
* Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* 3. All advertising materials mentioning features or use of this software
* must display the following acknowledgement:
* This product includes software developed by the Delft University of Technology.
* 4. Neither the name of the Delft University of Technology nor the names of
* its contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
* OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
* EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
* IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
* OF SUCH DAMAGE.
*
*/
/* This code was adapted, with minor modifications, from Steve Hanov's great tutorial at http://stevehanov.ca/blog/index.php?id=130 */
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include <new>
#include <queue>
#include <vector>
#ifndef VPTREE_H
#define VPTREE_H
/*
 * A point with an integer index and a privately owned copy of its coordinates.
 *
 * The coordinate buffer _x is allocated with malloc and released with free.
 * Copy construction and assignment make deep copies. A point that holds no
 * coordinates (for example a default-constructed one) has _x == NULL, and
 * copying such a point yields another point with _x == NULL.
 */
class DataPoint
{
    int _ind;       // index given at construction (-1 for a default-constructed point)

    // Returns a malloc'ed copy of the first D values of src.
    // Returns NULL when there is nothing to copy (src is NULL or D <= 0).
    // Throws std::bad_alloc when the allocation fails.
    static double* copy_coordinates(const double* src, int D) {
        if(src == NULL || D <= 0) return NULL;
        double* dst = static_cast<double*>(malloc(static_cast<size_t>(D) * sizeof(double)));
        if(dst == NULL) throw std::bad_alloc();
        for(int d = 0; d < D; d++) dst[d] = src[d];
        return dst;
    }

public:
    double* _x;     // owned coordinate buffer, or NULL when the point holds no data
    int _D;         // number of coordinates

    // Creates a point without coordinates: index -1, dimensionality 1, _x == NULL
    DataPoint() : _ind(-1), _x(NULL), _D(1) {}

    // Creates a point with index ind that stores its own copy of the first D values of x.
    // A NULL x leaves the point without coordinates.
    DataPoint(int D, int ind, double* x) : _ind(ind), _x(copy_coordinates(x, D)), _D(D) {}

    // Deep copy; the new object never shares its buffer with other
    DataPoint(const DataPoint& other)
        : _ind(other._ind), _x(copy_coordinates(other._x, other._D)), _D(other._D) {}

    ~DataPoint() { free(_x); }     // free(NULL) is a no-op

    // Deep-copy assignment; releases the buffer previously held by this object
    DataPoint& operator= (const DataPoint& other) {
        if(this != &other) {
            // Allocate first so that *this is left untouched if the copy fails
            double* fresh = copy_coordinates(other._x, other._D);
            free(_x);
            _x = fresh;
            _D = other._D;
            _ind = other._ind;
        }
        return *this;
    }

    int index() const { return _ind; }
    int dimensionality() const { return _D; }
    // Returns coordinate d; no bounds or NULL check is performed
    double x(int d) const { return _x[d]; }
};
/*
 * Euclidean (L2) distance between two points.
 *
 * The sum runs over t1._D coordinates of both buffers, so t2 must hold at
 * least that many values and neither buffer may be NULL (not checked here).
 *
 * Declared inline because the function is defined in a header: without it,
 * including this header from more than one translation unit produces
 * multiple-definition errors at link time.
 */
inline double euclidean_distance(const DataPoint &t1, const DataPoint &t2) {
    const double* x1 = t1._x;
    const double* x2 = t2._x;
    double dd = .0;
    for(int d = 0; d < t1._D; d++) {
        const double diff = x1[d] - x2[d];
        dd += diff * diff;
    }
    return std::sqrt(dd);
}
template<typename T, double (*distance)( const T&, const T& )>
class VpTree
{
public:
// Default constructor
VpTree() : _root(0) {}
// Destructor
~VpTree() {
delete _root;
}
// Function to create a new VpTree from data
void create(const std::vector<T>& items) {
delete _root;
_items = items;
_root = buildFromPoints(0, items.size());
}
// Function that uses the tree to find the k nearest neighbors of target
void search(const T& target, int k, std::vector<T>* results, std::vector<double>* distances)
{
// Use a priority queue to store intermediate results on
std::priority_queue<HeapItem> heap;
// Variable that tracks the distance to the farthest point in our results
_tau = DBL_MAX;
// Perform the search
search(_root, target, k, heap);
// Gather final results
results->clear(); distances->clear();
while(!heap.empty()) {
results->push_back(_items[heap.top().index]);
distances->push_back(heap.top().dist);
heap.pop();
}
// Results are in reverse order
std::reverse(results->begin(), results->end());
std::reverse(distances->begin(), distances->end());
}
private:
std::vector<T> _items;
double _tau;
// Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
struct Node
{
int index; // index of point in node
double threshold; // radius(?)
Node* left; // points closer by than threshold
Node* right; // points farther away than threshold
Node() :
index(0), threshold(0.), left(0), right(0) {}
~Node() { // destructor
delete left;
delete right;
}
}* _root;
// An item on the intermediate result queue
struct HeapItem {
HeapItem( int index, double dist) :
index(index), dist(dist) {}
int index;
double dist;
bool operator<(const HeapItem& o) const {
return dist < o.dist;
}
};
// Distance comparator for use in std::nth_element
struct DistanceComparator
{
const T& item;
DistanceComparator(const T& item) : item(item) {}
bool operator()(const T& a, const T& b) {
return distance(item, a) < distance(item, b);
}
};
// Function that (recursively) fills the tree
Node* buildFromPoints( int lower, int upper )
{
if (upper == lower) { // indicates that we're done here!
return NULL;
}
// Lower index is center of current node
Node* node = new Node();
node->index = lower;
if (upper - lower > 1) { // if we did not arrive at leaf yet
// Choose an arbitrary point and move it to the start
int i = (int) ((double)rand() / RAND_MAX * (upper - lower - 1)) + lower;
std::swap(_items[lower], _items[i]);
// Partition around the median distance
int median = (upper + lower) / 2;
std::nth_element(_items.begin() + lower + 1,
_items.begin() + median,
_items.begin() + upper,
DistanceComparator(_items[lower]));
// Threshold of the new node will be the distance to the median
node->threshold = distance(_items[lower], _items[median]);
// Recursively build tree
node->index = lower;
node->left = buildFromPoints(lower + 1, median);
node->right = buildFromPoints(median, upper);
}
// Return result
return node;
}
// Helper function that searches the tree
void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap)
{
if(node == NULL) return; // indicates that we're done here
// Compute distance between target and current node
double dist = distance(_items[node->index], target);
// If current node within radius tau
if(dist < _tau) {
if(heap.size() == k) heap.pop(); // remove furthest node from result list (if we already have k results)
heap.push(HeapItem(node->index, dist)); // add current node to result list
if(heap.size() == k) _tau = heap.top().dist; // update value of tau (farthest point in result list)
}
// Return if we arrived at a leaf
if(node->left == NULL && node->right == NULL) {
return;
}
// If the target lies within the radius of ball
if(dist < node->threshold) {
if(dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child first
search(node->left, target, k, heap);
}
if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child
search(node->right, target, k, heap);
}
// If the target lies outsize the radius of the ball
} else {
if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child first
search(node->right, target, k, heap);
}
if (dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child
search(node->left, target, k, heap);
}
}
}
};
#endif