https://github.com/DBWangGroupUNSW/SRS
Tip revision: 0bdf8aa436d237c23d278160dce85c589298daec authored by yifangs on 25 May 2015, 00:30:17 UTC
Update run_toy_data.sh
Update run_toy_data.sh
Tip revision: 0bdf8aa
SRSCoverTree.cpp
/*
* This file is part of SRS project.
*
* SRS is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* SRS is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with SRS. If not, see <http://www.gnu.org/licenses/>.
*
* Created by: Yifang Sun, Jianbin Qin
* Last modified by: Yifang Sun, Jianbin Qin
*/
#include "SRSCoverTree.h"

#include <climits>
// Entry of the best-first search queue used by the incremental k-NN search.
//
// An entry stands either for a subtree (node != NULL) or for a single data
// point (node == NULL, id = point id). min_dist is the priority key.
//
// All relational operators are intentionally reversed with respect to
// min_dist: std::push_heap / std::pop_heap build a max-heap on operator<, so
// reversing the comparison keeps the entry with the SMALLEST min_dist at the
// front of the heap.
struct search_node {
    float min_dist;
    CompressedTreeNode * node;
    int id;

    // "Greater" means a smaller distance key.
    bool operator>(const search_node & rhs) const {
        return rhs.min_dist > min_dist;
    }

    bool operator>=(const search_node & rhs) const {
        return rhs.min_dist >= min_dist;
    }

    // Equality looks at the distance key only.
    bool operator==(const search_node & rhs) const {
        return rhs.min_dist == min_dist;
    }

    bool operator<=(const search_node & rhs) const {
        return rhs.min_dist <= min_dist;
    }

    // "Less" means a larger distance key (this is what the heap uses).
    bool operator<(const search_node & rhs) const {
        return rhs.min_dist < min_dist;
    }
};
// Expansion constant of the tree: the radius of scale s is base^s.
float base = 1.3; // 2 in cover-tree paper
// 1 / ln(base); multiplying a natural log by it yields a log to `base`.
float il2 = 1. / log(base);

// Returns the radius associated with scale s, i.e. base^s.
inline float dist_of_scale(int s) {
    return pow(base, s);
}

// Returns ceil(log_base(d)), the scale whose radius first covers distance d.
//
// For d <= 0 (or NaN) the logarithm is -inf / NaN and converting that to int
// is undefined behaviour. batch_insert() relies on such inputs mapping to a
// very negative scale (it tests `next_scale < -100`), so return INT_MIN
// explicitly instead of depending on what the hardware conversion yields.
inline int get_scale(float d) {
    if (!(d > 0)) {
        return INT_MIN;
    }
    return (int) ceilf(il2 * log(d));
}
// Returns the smaller of two ints (a when they are equal).
inline int min(int a, int b) {
    return (a <= b) ? a : b;
}
// Returns the larger of two floats (a when they compare equal; b when the
// comparison a >= b is false, which includes the NaN case).
inline float max(float a, float b) {
    return (a >= b) ? a : b;
}
// Builds a cover tree over the n points held by `data`.
//
// n    : number of points
// d    : stored as the dimensionality of the points
// data : point container used for all distance computations; the destructor
//        of this class deletes it
//
// The tree is first bulk-loaded as linked TreeNodes (construct) and then
// flattened into the compressed pool (compressed_vectorization).
SRS_Cover_Tree::SRS_Cover_Tree(long long n, int d, Proj_data * data) {
    // Description of the input.
    this->data = data;
    this->n = n;
    this->d = d;

    // Nothing has been built or searched yet.
    this->num_node = 0;
    this->isVectorized = false;
    this->root = NULL;
    this->pool = NULL;
    this->heap = NULL;
    this->query = NULL;

    construct();
    compressed_vectorization();
}
// Loads a previously written index from "<file_path>.bin".
//
// Every member is given a defined value before the file is read, so that the
// members the loader does not touch (heap, query, root, ...) are never left
// as garbage. A loaded index only exists in its compressed form, hence
// isVectorized is set: new_node() must refuse to extend it.
SRS_Cover_Tree::SRS_Cover_Tree(char * file_path) {
    this->n = 0;
    this->d = 0;
    this->num_node = 0;
    this->data = NULL;
    this->root = NULL;
    this->compressed_root = NULL;
    this->pool = NULL;
    this->next_available = 0;
    this->heap = NULL;
    this->query = NULL;
    this->isVectorized = true;

    read_from_disk_compressed(file_path);
}
// Releases the compressed node pool (malloc'ed) and the point container.
// NOTE(review): `data` is deleted here even when it was handed in by the
// caller of the bulk-load constructor; confirm callers expect that.
SRS_Cover_Tree::~SRS_Cover_Tree() {
    free(this->pool);
    delete data;
}
// Allocates a childless TreeNode whose pivot is point `id` and counts it in
// num_node.
//
// Returns NULL (after printing a message) once the tree has been vectorized.
// The check is done before allocating so that the refused request does not
// leak a TreeNode.
TreeNode * SRS_Cover_Tree::new_node(int id) {
    if (isVectorized) {
        // Cannot change the tree node structure as it is vectorized.
        // This version only supports bulk-loading index construction.
        fprintf(stderr,
                "Cannot new node after cover-tree is already vectorized.\n");
        return NULL;
    }

    TreeNode * node = new TreeNode;
    node->pivot = id;
    node->max_dist = 0;
    node->nextSibling = NULL;
    node->firstChild = NULL;
    num_node++;
    return node;
}
// Returns the largest "last recorded distance" among the entries of `set`,
// or -1.0 when the set is empty.
float SRS_Cover_Tree::get_max(v_array<ds_node> set) {
    float largest = -1.0;
    int i = 0;
    while (i < set.index) {
        const float current = set[i].dist.last();
        if (largest < current) {
            largest = current;
        }
        ++i;
    }
    return largest;
}
// Partitions point_set in place by the radius of `scale`.
//
// Entries whose last recorded distance is within dist_of_scale(scale) stay in
// point_set (compacted to the front, order preserved); all others are
// appended to far_set.
void SRS_Cover_Tree::split(v_array<ds_node> &point_set,
        v_array<ds_node> &far_set, int scale) {
    const float radius = dist_of_scale(scale);
    unsigned int kept = 0;
    for (int i = 0; i < point_set.index; ++i) {
        if (point_set[i].dist.last() <= radius) {
            point_set[kept] = point_set[i];
            ++kept;
            continue;
        }
        push(far_set, point_set[i]);
    }
    point_set.index = kept;
}
// Moves the entries of point_set that lie within dist_of_scale(scale) of
// new_pivot into new_point_set.
//
// For every moved entry the distance to new_pivot is pushed onto its distance
// stack first. Entries that are too far stay in point_set, compacted to the
// front with their distance stack untouched.
void SRS_Cover_Tree::dist_split(v_array<ds_node>& point_set,
        v_array<ds_node>& new_point_set, int new_pivot,
        int scale) {
    const float radius = dist_of_scale(scale);
    unsigned int kept = 0;
    for (int i = 0; i < point_set.index; ++i) {
        float to_pivot = data->cal_dist(new_pivot, point_set[i].id);
        if (!(to_pivot <= radius)) {
            point_set[kept] = point_set[i];
            ++kept;
            continue;
        }
        push(point_set[i].dist, to_pivot);
        push(new_point_set, point_set[i]);
    }
    point_set.index = kept;
}
// Recursively builds the subtree for `pivot` out of the entries of point_set.
//
// pivot        : id of the point the returned node is built around
// max_scale    : scale of this level; its radius (dist_of_scale) decides which
//                entries are handled here and which are handed back
// top_scale    : passed through to the recursive calls unchanged; not
//                otherwise used in this function
// point_set    : in: candidate entries, whose last recorded distance is the
//                distance to `pivot`; out: the entries that were NOT placed
//                in the returned subtree
// consumed_set : receives every entry that was placed in the returned subtree
// stack        : pool of spare v_arrays that are popped for scratch use and
//                pushed back before returning
//
// Returns the node (or, see the note below, chain of nodes) created for pivot.
TreeNode * SRS_Cover_Tree::batch_insert(int pivot, int max_scale, int top_scale,
v_array<ds_node> &point_set,
v_array<ds_node> &consumed_set,
v_array<v_array<ds_node> > &stack) {
// No candidates left: the pivot becomes a childless node.
if (point_set.index == 0) {
return new_node(pivot);
}
float max_dist = get_max(point_set);
// Skip scales at which nothing could be separated.
int next_scale = min(max_scale - 1, get_scale(max_dist));
if (next_scale < -100) { // if points are with distance 0
// All remaining entries are treated as duplicates of the pivot: each one
// becomes a childless node, linked in front of the pivot's own node.
// NOTE(review): this returns the head of a sibling chain rather than a
// parent node. If it is returned into `new_child` in the loop below, the
// `new_child->nextSibling = child` assignment replaces the head's link and
// the rest of the chain looks unreachable -- confirm.
TreeNode * firstChild = new_node(pivot);
while (point_set.index > 0) {
TreeNode * temp_node = new_node(point_set.last().id);
temp_node->nextSibling = firstChild;
firstChild = temp_node;
push(consumed_set, point_set.last());
point_set.decr();
}
return firstChild;
}
// Entries beyond the radius of max_scale are set aside in `far`.
v_array<ds_node> far = pop(stack);
split(point_set, far, max_scale);
// First child: the pivot itself, one level further down.
TreeNode * child = batch_insert(pivot, next_scale, top_scale, point_set,
consumed_set, stack);
if (point_set.index == 0) {
// Every near entry was absorbed by that child, so no node is needed at
// this level: recycle the empty buffer and hand `far` back to the caller.
push(stack, point_set);
point_set = far;
return child;
}
TreeNode * node = new_node(pivot);
v_array<ds_node> new_point_set = pop(stack);
v_array<ds_node> new_consumed_set = pop(stack);
// Each remaining near entry in turn becomes the pivot of a further child.
while (point_set.index != 0) {
int new_pivot = point_set.last().id;
push(consumed_set, point_set.last());
point_set.decr();
// Collect, from both the near and the far entries, those within the
// radius of max_scale of the new pivot.
dist_split(point_set, new_point_set, new_pivot, max_scale);
dist_split(far, new_point_set, new_pivot, max_scale);
TreeNode * new_child = batch_insert(new_pivot, next_scale, top_scale,
new_point_set, new_consumed_set, stack);
// Prepend the new child to the sibling list.
new_child->nextSibling = child;
child = new_child;
float thres = dist_of_scale(max_scale);
// Entries the child did not take: drop the distance to new_pivot and put
// them back into the near or far set according to the distance that is
// now on top of their stack.
for (int i = 0; i < new_point_set.index; ++i) {
new_point_set[i].dist.decr();
if (new_point_set[i].dist.last() <= thres) {
push(point_set, new_point_set[i]);
} else {
push(far, new_point_set[i]);
}
}
// Entries the child did take are reported upwards, again without the
// distance to new_pivot.
for (int i = 0; i < new_consumed_set.index; ++i) {
new_consumed_set[i].dist.decr();
push(consumed_set, new_consumed_set[i]);
}
// Reset the scratch arrays for the next iteration (storage is kept).
new_point_set.index = 0;
new_consumed_set.index = 0;
}
// Return the scratch buffers to the pool; the caller gets `far` back as the
// entries that were not placed.
push(stack, new_point_set);
push(stack, new_consumed_set);
push(stack, point_set);
point_set = far;
node->max_dist = get_max(consumed_set);
node->firstChild = child;
return node;
}
void SRS_Cover_Tree::construct() {
v_array<ds_node> point_set;
v_array<ds_node> consumed_set;
v_array<v_array<ds_node> > stack;
float max_dist = -1.0;
for (int i = 1; i < n; ++i) { // take the first point as pivot
ds_node temp;
temp.id = i;
push(temp.dist, data->cal_dist(0, i));
push(point_set, temp);
if (temp.dist.last() > max_dist) {
max_dist = temp.dist.last();
}
}
this->root = batch_insert(0, get_scale(max_dist), get_scale(max_dist),
point_set, consumed_set, stack);
for (int i = 0; i < consumed_set.index; i++) {
free(consumed_set[i].dist.elements);
}
free(consumed_set.elements);
for (int i = 0; i < stack.index; i++) {
free(stack[i].elements);
}
free(stack.elements);
free(point_set.elements);
}
void SRS_Cover_Tree::init_search(float * query) {
search_node node;
node.min_dist = max(
0,
data->cal_dist(this->compressed_root->pivot, query)
- this->compressed_root->max_dist);
node.node = this->compressed_root;
node.id = -1;
this->query = query;
this->heap = new std::vector<search_node>;
heap->push_back(node);
heap->reserve(n / 100); // Avoid too many re-size of heap by reserving the heap size to be 1% of the total number of points.
}
// Returns the next result of the search started by init_search().
//
// Repeatedly removes the heap entry with the smallest key. A point entry
// (node == NULL) is returned as {id, distance}. A subtree entry is expanded:
// each external child is pushed as a point entry keyed by its distance to the
// query, each internal child as a subtree entry keyed by
// max(0, dist(pivot, query) - max_dist).
//
// Returns {-1, -1} once the heap is exhausted.
res_pair SRS_Cover_Tree::increm_knn_search_compressed() {
    while (!heap->empty()) {
        const search_node top = heap->front();
        std::pop_heap(heap->begin(), heap->end());
        heap->pop_back();

        if (top.node == NULL) {
            res_pair hit = { top.id, top.min_dist };
            return hit;
        }

        CompressedTreeNode * parent = top.node;

        // Points stored directly under this node.
        int * point_ids = parent->get_external_children();
        for (int i = 0; i < parent->num_of_external_node; ++i) {
            search_node entry;
            entry.node = NULL;
            entry.id = point_ids[i];
            entry.min_dist = data->cal_dist(entry.id, query);
            heap->push_back(entry);
            std::push_heap(heap->begin(), heap->end());
        }

        // Subtrees below this node.
        CompressedTreeNode * subtrees = parent->get_internal_children();
        for (int i = 0; i < parent->num_of_internal_node; ++i) {
            CompressedTreeNode * sub = &subtrees[i];
            search_node entry;
            entry.id = -1;
            entry.node = sub;
            entry.min_dist = max(0,
                    data->cal_dist(sub->pivot, query) - sub->max_dist);
            heap->push_back(entry);
            std::push_heap(heap->begin(), heap->end());
        }
    }

    res_pair res = { -1, -1 };
    return res;
}
void SRS_Cover_Tree::finish_search() {
delete this->heap;
delete[] this->query;
}
// Converts the linked subtree rooted at `node` into its compressed form,
// writing the node record into *new_node and the child block into the pool,
// and deletes the linked TreeNodes of that subtree as it goes.
//
// new_node : destination record (already reserved inside the pool)
// node     : linked node to convert; deleted before returning
//
// Children without children of their own ("external") are stored as point
// ids; the others ("internal") as CompressedTreeNode records that are filled
// recursively.
void SRS_Cover_Tree::compressed_vectorization(CompressedTreeNode * new_node,
        TreeNode * node) {
    // Count in int: with unsigned char counters the value 256 can never be
    // observed (it wraps to 0), which made the fan-out check below dead code.
    int internal_node = 0, external_node = 0;
    for (TreeNode * child = node->firstChild; child != NULL;
            child = child->nextSibling) {
        if (child->firstChild == NULL) {
            external_node++;
        } else {
            internal_node++;
        }
        if (external_node == 256 || internal_node == 256) {
            fprintf(
                    stderr,
                    "ABORT: fanout of the tree exceed 256, change the type in Line 47-48 in SRSCoverTree.h\n");
            // The index cannot be represented; returning here would leave
            // *new_node unpopulated for the search to trip over, so stop as
            // the message announces.
            exit(EXIT_FAILURE);
        }
    }

    new_node->pivot = node->pivot;
    new_node->max_dist = node->max_dist;
    new_node->num_of_internal_node = internal_node;
    new_node->num_of_external_node = external_node;
    // Reserve this node's child block: external_node ints plus internal_node
    // CompressedTreeNode records.
    new_node->children = pool + next_available;
    next_available += sizeof(int) * external_node
            + sizeof(CompressedTreeNode) * internal_node;

    // external nodes: copy the point ids, skipping internal siblings
    TreeNode * temp = node->firstChild;
    int * external_children = new_node->get_external_children();
    for (int i = 0; i < external_node; ++i) {
        while (temp->firstChild != NULL) {
            temp = temp->nextSibling;
        }
        external_children[i] = temp->pivot;
        temp = temp->nextSibling;
    }

    // internal nodes: recurse, deleting the leaf siblings passed on the way
    temp = node->firstChild;
    CompressedTreeNode * internal_children = new_node->get_internal_children();
    for (int i = 0; i < internal_node; ++i) {
        while (temp->firstChild == NULL) {
            TreeNode * leaf = temp;
            temp = temp->nextSibling;
            delete leaf;
        }
        // The recursive call deletes temp, so fetch its successor first.
        TreeNode * next = temp->nextSibling;
        compressed_vectorization(&internal_children[i], temp);
        temp = next;
    }

    // Whatever is left behind the last internal child are leaves only; they
    // used to be leaked.
    while (temp != NULL) {
        TreeNode * leaf = temp;
        temp = temp->nextSibling;
        delete leaf;
    }

    delete node;
}
void SRS_Cover_Tree::compressed_vectorization() {
pool = (char *) malloc(
sizeof(CompressedTreeNode) * (num_node - n) + sizeof(int) * n);
next_available = 0;
CompressedTreeNode * new_root = (CompressedTreeNode *) pool;
next_available += sizeof(CompressedTreeNode);
compressed_vectorization(new_root, root);
//delete root;
compressed_root = new_root;
isVectorized = true;
}
// Writes the compressed index to "<file_path>.bin".
//
// File layout, in order: num_node (int), n (long long), d (int), the current
// address of the pool (the reader uses it to rebase the stored child
// pointers), the raw pool bytes, and the n * d floats of the point data.
//
// The file name is built with a bounded snprintf and every I/O step is
// checked; on any failure a message is printed to stderr and the function
// returns (it has no other way to report an error).
void SRS_Cover_Tree::write_to_disk_compressed(char * file_path) {
    char fileex[255];
    const int len = snprintf(fileex, sizeof(fileex), "%s.bin", file_path);
    if (len < 0 || (size_t) len >= sizeof(fileex)) {
        fprintf(stderr, "Index path is too long: %s.bin\n", file_path);
        return;
    }

    // Write out tree as binary file.
    FILE * fp = fopen(fileex, "wb");
    if (fp == NULL) {
        fprintf(stderr, "Cannot open %s for writing.\n", fileex);
        return;
    }

    const size_t pool_bytes = sizeof(CompressedTreeNode) * (num_node - n)
            + sizeof(int) * n;
    const size_t num_floats = (size_t) (n * d);

    bool ok = fwrite(&num_node, sizeof(int), 1, fp) == 1
            && fwrite(&n, sizeof(long long), 1, fp) == 1
            && fwrite(&d, sizeof(int), 1, fp) == 1
            // This is for offset.
            && fwrite(&pool, sizeof(CompressedTreeNode*), 1, fp) == 1
            && fwrite(pool, pool_bytes, 1, fp) == 1
            && fwrite(data->data, sizeof(float), num_floats, fp) == num_floats;

    // fclose flushes buffered output, so its result matters as well.
    if (fclose(fp) != 0) {
        ok = false;
    }
    if (!ok) {
        fprintf(stderr, "Failed to write index file %s.\n", fileex);
    }
}
// Called by read_from_disk_compressed.
//
// Rebases the `children` pointer of `node` and of every internal descendant
// from the pool address recorded at write time (tmppool) to the address of
// the freshly loaded pool (pool), keeping the offset into the pool unchanged.
void adjust_offset(CompressedTreeNode * node, char * tmppool, char * pool) {
    node->children = (char *) (pool + ((char *) node->children - tmppool));

    CompressedTreeNode * subtrees = node->get_internal_children();
    for (int i = 0; i < node->num_of_internal_node; ++i) {
        adjust_offset(subtrees + i, tmppool, pool);
    }
}
// Loads an index written by write_to_disk_compressed from "<file_path>.bin".
//
// Reads the header (num_node, n, d, pool address at write time), the pool
// bytes and the point data, then rebases the child pointers stored in the
// pool to the new pool address.
//
// The file name is built with a bounded snprintf and every step is checked.
// On failure a message is printed to stderr and pool, compressed_root and
// data are left NULL, so that the destructor stays safe.
void SRS_Cover_Tree::read_from_disk_compressed(char * file_path) {
    char fileex[255];
    char * tmppool = NULL;
    FILE * fp = NULL;

    // Defined state for every exit path.
    this->pool = NULL;
    this->compressed_root = NULL;
    this->data = NULL;

    const int len = snprintf(fileex, sizeof(fileex), "%s.bin", file_path);
    if (len >= 0 && (size_t) len < sizeof(fileex)) {
        fp = fopen(fileex, "rb");
    }
    if (fp == NULL) {
        fprintf(stderr, "Cannot open index file %s.bin for reading.\n",
                file_path);
        return;
    }

    // Header; the sanity checks keep a damaged file from producing a
    // nonsensical allocation size.
    bool ok = fread(&num_node, sizeof(int), 1, fp) == 1
            && fread(&n, sizeof(long long), 1, fp) == 1
            && fread(&d, sizeof(int), 1, fp) == 1
            && fread(&tmppool, sizeof(CompressedTreeNode*), 1, fp) == 1
            && n >= 0 && d >= 0 && num_node >= n;

    if (ok) {
        const size_t pool_bytes = sizeof(CompressedTreeNode) * (num_node - n)
                + sizeof(int) * n;
        // The pool must at least hold the root record.
        ok = pool_bytes >= sizeof(CompressedTreeNode);
        if (ok) {
            pool = (char *) malloc(pool_bytes);
            ok = pool != NULL && fread(pool, 1, pool_bytes, fp) == pool_bytes;
        }
    }

    if (ok) {
        // adjust the pointer offset
        this->compressed_root = (CompressedTreeNode *) pool;
        adjust_offset(this->compressed_root, tmppool, pool);

        // read data
        const size_t num_floats = (size_t) (n * d);
        this->data = new Proj_data(n, d, NULL);
        ok = fread(this->data->data, sizeof(float), num_floats, fp)
                == num_floats;
    }

    fclose(fp);

    if (!ok) {
        fprintf(stderr, "Failed to load index file %s.\n", fileex);
        delete this->data;
        this->data = NULL;
        free(pool);
        pool = NULL;
        this->compressed_root = NULL;
    }
}
// Collects statistics of the compressed subtree rooted at `node`.
//
// node       : subtree root; NULL stands for a single point (leaf)
// my_height  : height of `node` itself
// max_height : raised to the largest height seen
// max_fanout : raised to the largest number of children (internal plus
//              external) seen on one node
//
// External children are leaves at my_height + 1. One height update covers all
// of them, so no recursive call per point is made and the point ids are not
// fetched.
void traverse(CompressedTreeNode * node, int my_height, int & max_height,
        int & max_fanout) {
    if (node == NULL) { //leaf
        if (my_height > max_height) {
            max_height = my_height;
        }
        return;
    }

    //internal node
    if (node->num_of_external_node > 0 && my_height + 1 > max_height) {
        max_height = my_height + 1;
    }

    CompressedTreeNode * internal_children = node->get_internal_children();
    for (int i = 0; i < node->num_of_internal_node; ++i) {
        traverse(&internal_children[i], my_height + 1, max_height, max_fanout);
    }

    if (node->num_of_internal_node + node->num_of_external_node > max_fanout) {
        max_fanout = node->num_of_internal_node + node->num_of_external_node;
    }
}
// Prints the maximum height and the maximum fan-out of the compressed tree
// to stdout.
void SRS_Cover_Tree::tree_stat() {
    int max_height = 0;
    int max_fanout = 0;
    const int root_height = 0;
    traverse(this->compressed_root, root_height, max_height, max_fanout);
    printf("max height: %d\nmax fanout: %d\n", max_height, max_fanout);
}