https://github.com/lorenzhs/BuRR
Tip revision: 1c62832ad7d6eab5b337f386955868c3ce9a54ea authored by Lorenz Hübschle-Schneider on 11 September 2021, 12:56:53 UTC
README: paper link, bibtex
construction.hpp
// Copyright (c) Lorenz Hübschle-Schneider
// Copyright (c) Facebook, Inc. and its affiliates.
// All Rights Reserved. This source code is licensed under the Apache 2.0
// License (found in the LICENSE file in the root directory).
#pragma once
#include "minimal_hasher.hpp"
#include "rocksdb/math.h"
#include "rocksdb/stop_watch.h"
#include "sorter.hpp"
#include <tlx/logger.hpp>
#include <tlx/logger/wrap_unprintable.hpp>
#ifndef RIBBON_USE_STD_SORT
// Use in-place super-scalar radix sorter ips2ra, which is around 3x faster for
// the inputs used here
#include <ips2ra.hpp>
#endif
#include <algorithm>
#include <cassert>
#include <functional>
#include <tuple>
#include <vector>
namespace ribbon {
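// Add a single (coefficient row, result row) equation to the banding storage
// using on-the-fly Gaussian elimination: starting at slot `start`, XOR the
// already-stored row into (coeffs, result) until either an empty slot is
// found (success) or the coefficients cancel to zero (linear dependence,
// failure). Returns (success, slot at which insertion succeeded or failed).
// If kFCA1 is set, the first coefficient is known to be one and the initial
// shift to the lowest set bit is skipped.
//
// Illustrative elimination step (numbers made up): with 4-bit coefficient
// rows, inserting coeffs = 0b1011 at slot 5 while slot 5 already holds
// 0b1101 gives coeffs ^ other = 0b0110; one trailing zero is shifted out, so
// the insertion continues at slot 6 with coeffs = 0b011, restoring the
// invariant that the lowest coefficient bit is one.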
template <bool kFCA1, typename BandingStorage>
std::pair<bool, typename BandingStorage::Index>
BandingAdd(BandingStorage *bs, typename BandingStorage::Index start,
typename BandingStorage::CoeffRow coeffs,
typename BandingStorage::ResultRow result) {
using Index = typename BandingStorage::Index;
constexpr bool debug = false;
Index pos = start;
if constexpr (!kFCA1) {
int tz = rocksdb::CountTrailingZeroBits(coeffs);
pos += static_cast<size_t>(tz);
coeffs >>= tz;
}
while (true) {
assert(pos < bs->GetNumSlots());
assert((coeffs & 1) == 1);
auto other = bs->GetCoeffs(pos);
if (other == 0) {
// found an empty slot, insert
bs->SetCoeffs(pos, coeffs);
bs->SetResult(pos, result);
sLOG << "Insertion succeeded at position" << pos;
return std::make_pair(true, pos);
}
assert((other & 1) == 1);
coeffs ^= other;
result ^= bs->GetResult(pos);
if (coeffs == 0) {
// linearly dependent!
sLOG << "Insertion failed at position" << pos;
return std::make_pair(false, pos);
}
// move to now-leading coefficient
int tz = rocksdb::CountTrailingZeroBits(coeffs);
pos += static_cast<size_t>(tz);
coeffs >>= tz;
}
LOG << "Insertion hit the end of the loop";
return std::pair(result == 0, pos);
}
// hack to prevent inlining of ips2ra, which is awful for compile time and
// produces ginormous binaries
template <typename Iterator>
__attribute__((noinline)) void my_sort(Iterator begin, Iterator end) {
#ifdef RIBBON_USE_STD_SORT
// Use std::sort as a slow fallback
std::sort(begin, end, [](const auto &a, const auto &b) {
return std::get<0>(a) < std::get<0>(b);
});
#else
ips2ra::sort(begin, end, [](const auto &x) { return std::get<0>(x); });
#endif
}
// hack to prevent inlining of ips2ra, which is awful for compile time and
// produces ginormous binaries
template <typename Iterator, typename Hasher, typename Index>
__attribute__((noinline)) void my_sort(Iterator begin, Iterator end,
const Hasher &h, Index num_starts) {
unsigned sparse_shift = 0;
if constexpr (Hasher::kSparseCoeffs) {
sparse_shift = Hasher::shift_;
}
MinimalHasher<Index, Hasher::kSparseCoeffs> mh(Hasher::kBucketSize,
h.GetFactor(), sparse_shift);
return Sorter<Index, Hasher::kIsFilter, Hasher::kSparseCoeffs,
std::conditional_t<Hasher::kIsFilter, SorterDummyData,
typename Hasher::ResultRow>>()
.do_sort(begin, end, mh, num_starts);
}
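// Construct the ribbon from the items in [begin, end). Each item is hashed to
// a start slot, the items are sorted by bucket and intra-bucket position, and
// then inserted bucket by bucket. On the first failed insertion in a bucket,
// the item's compressed intra-bucket value becomes the bucket's threshold
// (stored via SetMeta): the failing item and all cached insertions with the
// same compressed value are undone and appended to *bump_vec, and every
// remaining item of that bucket is bumped directly. With bump_vec == nullptr,
// bumping is disabled and the first failure aborts construction (returns
// false).
//
// Rough usage sketch (type names are placeholders, not part of this file):
//   SomeBandingStorage bs; SomeHasher hasher; std::vector<Key> bumped;
//   ribbon::BandingAddRange(&bs, hasher, keys.begin(), keys.end(), &bumped);
//   // `bumped` now holds the items to be inserted into the next layer.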
template <typename BandingStorage, typename Hasher, typename Iterator,
typename BumpStorage = std::vector<typename std::iterator_traits<Iterator>::value_type>>
bool BandingAddRange(BandingStorage *bs, Hasher &hasher, Iterator begin,
Iterator end, BumpStorage *bump_vec) {
using CoeffRow = typename BandingStorage::CoeffRow;
using Index = typename BandingStorage::Index;
using ResultRow = typename BandingStorage::ResultRow;
using Hash = typename Hasher::Hash;
constexpr bool kFCA1 = Hasher::kFirstCoeffAlwaysOne;
constexpr bool oneBitThresh = Hasher::kThreshMode == ThreshMode::onebit;
constexpr bool debug = false;
constexpr bool log = Hasher::log;
if (begin == end)
return true;
rocksdb::StopWatchNano timer(true);
const Index num_starts = bs->GetNumStarts();
const Index num_buckets = bs->GetNumBuckets();
sLOG << "Constructing ribbon with" << num_buckets
<< "buckets, num_starts = " << num_starts;
const auto num_items = end - begin; // std::distance(begin, end);
#ifdef RIBBON_PASS_HASH
constexpr bool sparse = Hasher::kSparseCoeffs && Hasher::kCoeffBits < 128;
auto input = std::make_unique<
std::tuple<Index, Index, std::conditional_t<sparse, uint32_t, Hash>>[]>(
num_items);
#else
auto input = std::make_unique<std::pair<Index, Index>[]>(num_items);
#endif
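// With RIBBON_PASS_HASH, the (possibly compacted) hash is carried through the
// sort so it does not have to be recomputed in the insertion loop; otherwise
// only (sort position, input index) pairs are sorted and the hash is
// recomputed from the (prefetched) input items in the loop below.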
{
sLOG << "Processing" << num_items << "items";
for (Index i = 0; i < static_cast<Index>(num_items); i++) {
const Hash h = hasher.GetHash(*(begin + i));
const Index start = hasher.GetStart(h, num_starts);
const Index sortpos = Hasher::StartToSort(start);
#ifdef RIBBON_PASS_HASH
if constexpr (sparse) {
uint32_t compact_hash = hasher.GetCompactHash(h);
input[i] = std::make_tuple(sortpos, i, compact_hash);
} else {
input[i] = std::make_tuple(sortpos, i, h);
}
#else
input[i] = std::make_pair(sortpos, i);
#endif
}
}
LOGC(log) << "\tInput transformation took "
<< timer.ElapsedNanos(true) / 1e6 << "ms";
my_sort(input.get(), input.get() + num_items);
LOGC(log) << "\tSorting took " << timer.ElapsedNanos(true) / 1e6 << "ms";
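// Undo the cached insertions in `vec`: clear the affected rows and append the
// corresponding input items to *bump_vec.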
const auto do_bump = [&](auto &vec) {
sLOG << "Bumping" << vec.size() << "items";
for (auto [row, idx] : vec) {
sLOG << "\tBumping row" << row << "item"
<< tlx::wrap_unprintable(*(begin + idx));
bs->SetCoeffs(row, 0);
bs->SetResult(row, 0);
bump_vec->push_back(*(begin + idx));
}
vec.clear();
};
Index last_bucket = 0;
bool all_good = true;
Index thresh = Hasher::NoBumpThresh();
// Bump cache: (row, input index) pairs that may have to be bumped retroactively
Index last_cval = -1;
std::vector<std::pair<Index, Index>> bump_cache;
// For 1-bit thresholds, we also need an uncompressed bump cache for undoing
// all insertions with the same uncompressed value if we end up in the
// "plus" case with a separately stored threshold
[[maybe_unused]] Index last_val = -1;
[[maybe_unused]] std::conditional_t<oneBitThresh, decltype(bump_cache), int> unc_bump_cache;
#ifndef RIBBON_PASS_HASH
auto next = *(begin + input[0].second);
#endif
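// Main insertion loop: walk the items in sorted order, tracking the current
// bucket and its threshold candidate, plus caches of already-inserted
// (row, index) pairs per compressed value (and per uncompressed value for
// 1-bit thresholds) so they can be undone if a later insertion in the same
// bucket fails.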
for (Index i = 0; i < static_cast<Index>(num_items); ++i) {
#ifdef RIBBON_PASS_HASH
const auto [sortpos, idx, hash] = input[i];
#else
const auto [sortpos, idx] = input[i];
#endif
const Index start = Hasher::SortToStart(sortpos),
bucket = Hasher::GetBucket(sortpos),
val = Hasher::GetIntraBucket(sortpos),
cval = hasher.Compress(val);
assert(bucket >= last_bucket);
assert(oneBitThresh || cval < Hasher::NoBumpThresh());
#ifndef RIBBON_PASS_HASH
const Hash hash = hasher.GetHash(next);
if (i + 1 < num_items)
next = *(begin + input[i + 1].second);
// prefetch the cache miss far in advance, assuming the iterator
// points to contiguous storage
if (TLX_LIKELY(i + 32 < num_items))
__builtin_prefetch(&*begin + input[i + 32].second, 0, 1);
#endif
if (bucket != last_bucket) {
// moving to next bucket
sLOG << "Moving to bucket" << bucket << "was" << last_bucket;
if constexpr (oneBitThresh) {
unc_bump_cache.clear();
last_val = val;
}
if (thresh == Hasher::NoBumpThresh()) {
sLOG << "Bucket" << last_bucket << "has no bumped items";
bs->SetMeta(last_bucket, thresh);
}
all_good = true;
last_bucket = bucket;
thresh = Hasher::NoBumpThresh(); // maximum == "no bumpage"
last_cval = cval;
bump_cache.clear();
} else if (!all_good) {
// direct hard bump
sLOG << "Directly bumping" << tlx::wrap_unprintable(*(begin + idx))
<< "from bucket" << bucket << "val" << val << cval << "start"
<< start << "sort" << sortpos << "hash" << std::hex << hash
<< "data"
<< (uint64_t)(Hasher::kIsFilter
? hasher.GetResultRowFromHash(hash)
: hasher.GetResultRowFromInput(*(begin + idx)))
<< std::dec;
bump_vec->push_back(*(begin + idx));
continue;
} else if (cval != last_cval) {
// clear bump cache
sLOG << "Bucket" << bucket << "cval" << cval << "!=" << last_cval;
bump_cache.clear();
last_cval = cval;
}
if constexpr (oneBitThresh) {
// split into constexpr and normal if because unc_bump_cache isn't a
// vector if !oneBitThresh
if (val != last_val) {
unc_bump_cache.clear();
last_val = val;
}
}
const CoeffRow cr = hasher.GetCoeffs(hash);
const ResultRow rr = Hasher::kIsFilter
? hasher.GetResultRowFromHash(hash)
: hasher.GetResultRowFromInput(*(begin + idx));
auto [success, row] = BandingAdd<kFCA1>(bs, start, cr, rr);
if (!success) {
assert(all_good);
if (bump_vec == nullptr)
// bumping disabled, abort!
return false;
// if we got this far, this is the first failure in this bucket,
// and we need to undo insertions with the same cval
sLOG << "First failure in bucket" << bucket << "val" << val
<< "start" << start << "sort" << sortpos << "hash" << std::hex
<< hash << "data" << (uint64_t)rr << std::dec << "for item"
<< tlx::wrap_unprintable(*(begin + idx)) << "-> threshold"
<< cval << "clash in row" << row;
thresh = cval;
if constexpr (oneBitThresh) {
if (cval == 2) {
sLOG << "First failure in bucket" << bucket << "val" << val
<< "is a 'plus' case (below threshold)";
// "plus" case: store uncompressed threshold in hash table
hasher.Set(bucket, val);
// Set meta to 0 (some bumpage) but keep thresh at 2 so that
// we don't retroactively bump everything when moving to the
// next bucket
bs->SetMeta(bucket, 0);
all_good = false;
// bump all items with the same uncompressed value
do_bump(unc_bump_cache);
sLOG << "Also bumping"
<< tlx::wrap_unprintable(*(begin + idx));
bump_vec->push_back(*(begin + idx));
// proceed to next item, don't do regular bumping (but only
// if cval == 2, not generally!)
continue;
}
}
bs->SetMeta(bucket, thresh);
all_good = false;
do_bump(bump_cache);
bump_vec->push_back(*(begin + idx));
} else {
sLOG << "Insertion succeeded of item"
<< tlx::wrap_unprintable(*(begin + idx)) << "in pos" << row
<< "bucket" << bucket << "val" << val << cval << "start"
<< start << "sort" << sortpos << "hash" << std::hex << hash
<< "data" << (uint64_t)rr << std::dec;
bump_cache.emplace_back(row, idx);
if constexpr (oneBitThresh) {
// also record in bump cache for uncompressed values
unc_bump_cache.emplace_back(row, idx);
}
}
}
// set final threshold
if (thresh == Hasher::NoBumpThresh()) {
bs->SetMeta(last_bucket, thresh);
}
// migrate thresholds to hash table
if constexpr (oneBitThresh) {
hasher.Finalise(num_buckets);
}
LOGC(log) << "\tActual insertion took " << timer.ElapsedNanos(true) / 1e6
<< "ms";
return true;
}
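// Same construction as BandingAddRange, but for hashers operating on
// precomputed master hash codes (MHCs): the input range contains the MHCs (or
// MHC/value pairs in the retrieval case) and is sorted in place by start
// position. MHCs are expected to be unique (checked by the assertions below),
// and the entries pushed to *bump_vec are MHCs, not original keys.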
template <typename BandingStorage, typename Hasher, typename Iterator,
typename BumpStorage = std::vector<typename Hasher::mhc_t>>
bool BandingAddRangeMHC(BandingStorage *bs, Hasher &hasher, Iterator begin,
Iterator end, BumpStorage *bump_vec) {
static_assert(Hasher::kUseMHC, "you called the wrong method");
using CoeffRow = typename BandingStorage::CoeffRow;
using Index = typename BandingStorage::Index;
using ResultRow = typename BandingStorage::ResultRow;
constexpr bool kFCA1 = Hasher::kFirstCoeffAlwaysOne;
constexpr bool oneBitThresh = (Hasher::kThreshMode == ThreshMode::onebit);
constexpr bool debug = false;
constexpr bool log = Hasher::log;
if (begin == end)
return true;
rocksdb::StopWatchNano timer(true);
const auto num_items = end - begin;
const Index num_starts = bs->GetNumStarts();
const Index num_buckets = bs->GetNumBuckets();
sLOG << "Constructing ribbon (MHC) with" << num_buckets
<< "buckets, num_starts =" << num_starts;
my_sort(begin, end, hasher, num_starts);
// MHCs should be unique; if not, construction fails
if constexpr (Hasher::kIsFilter) {
assert(std::adjacent_find(begin, end) == end);
} else {
assert(std::adjacent_find(begin, end, [](const auto &a, const auto &b) {
return a.first == b.first;
}) == end);
}
LOGC(log) << "\tSorting took " << timer.ElapsedNanos(true) / 1e6 << "ms";
const auto do_bump = [&](auto &vec) {
sLOG << "Bumping" << vec.size() << "items";
for (auto [row, idx] : vec) {
sLOG << "\tBumping row" << row << "item"
<< tlx::wrap_unprintable(*(begin + idx));
bs->SetCoeffs(row, 0);
bs->SetResult(row, 0);
bump_vec->push_back(*(begin + idx));
}
vec.clear();
};
Index last_bucket = 0;
bool all_good = true;
Index thresh = Hasher::NoBumpThresh();
// Bump cache: (row, input index) pairs that may have to be bumped retroactively
Index last_cval = -1;
std::vector<std::pair<Index, Index>> bump_cache;
// For 1-bit thresholds, we also need an uncompressed bump cache for undoing
// all insertions with the same uncompressed value if we end up in the
// "plus" case with a separately stored threshold
[[maybe_unused]] Index last_val = -1;
[[maybe_unused]] std::conditional_t<oneBitThresh, decltype(bump_cache), int> unc_bump_cache;
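// Main insertion loop, mirroring BandingAddRange: since the sorted range holds
// the MHCs themselves, hashes are recomputed directly from *(begin + i) and no
// index indirection or prefetching is needed.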
for (Index i = 0; i < static_cast<Index>(num_items); ++i) {
const auto mhc = *(begin + i);
const auto hash = hasher.GetHash(mhc);
const Index start = hasher.GetStart(hash, num_starts),
sortpos = Hasher::StartToSort(start),
bucket = Hasher::GetBucket(sortpos),
val = Hasher::GetIntraBucket(sortpos),
cval = hasher.Compress(val);
assert(bucket >= last_bucket);
assert(oneBitThresh || cval < Hasher::NoBumpThresh());
if (bucket != last_bucket) {
// moving to next bucket
sLOG << "Moving to bucket" << bucket << "was" << last_bucket;
if constexpr (oneBitThresh) {
unc_bump_cache.clear();
last_val = val;
}
if (thresh == Hasher::NoBumpThresh()) {
sLOG << "Bucket" << last_bucket << "has no bumped items";
bs->SetMeta(last_bucket, thresh);
}
all_good = true;
last_bucket = bucket;
thresh = Hasher::NoBumpThresh(); // maximum == "no bumpage"
last_cval = cval;
bump_cache.clear();
} else if (!all_good) {
// direct hard bump
sLOG << "Directly bumping" << tlx::wrap_unprintable(*(begin + i))
<< "from bucket" << bucket << "val" << val << cval << "start"
<< start << "sort" << sortpos << "hash" << std::hex << hash
<< std::dec;
bump_vec->push_back(*(begin + i));
continue;
} else if (cval != last_cval) {
// clear bump cache
sLOG << "Bucket" << bucket << "cval" << cval << "!=" << last_cval;
bump_cache.clear();
last_cval = cval;
}
if constexpr (oneBitThresh) {
// split into constexpr and normal if because unc_bump_cache isn't a
// vector if !oneBitThresh
if (val != last_val) {
unc_bump_cache.clear();
last_val = val;
}
}
const CoeffRow cr = hasher.GetCoeffs(hash);
const ResultRow rr = Hasher::kIsFilter
? hasher.GetResultRowFromHash(hash)
: hasher.GetResultRowFromInput(*(begin + i));
auto [success, row] = BandingAdd<kFCA1>(bs, start, cr, rr);
if (!success) {
assert(all_good);
if (bump_vec == nullptr)
// bumping disabled, abort!
return false;
// if we got this far, this is the first failure in this bucket,
// and we need to undo insertions with the same cval
sLOG << "First failure in bucket" << bucket << "val" << val
<< "start" << start << "sort" << sortpos << "hash" << std::hex
<< hash << "data" << (uint64_t)rr << std::dec << "for item"
<< tlx::wrap_unprintable(*(begin + i)) << "-> threshold"
<< cval << "clash in row" << row;
thresh = cval;
if constexpr (oneBitThresh) {
if (cval == 2) {
sLOG << "First failure in bucket" << bucket << "val" << val
<< "is a 'plus' case (below threshold)";
// "plus" case: store uncompressed threshold in hash table
hasher.Set(bucket, val);
// Set meta to 0 (some bumpage) but keep thresh at 2 so that
// we don't retroactively bump everything when moving to the
// next bucket
bs->SetMeta(bucket, 0);
all_good = false;
// bump all items with the same uncompressed value
do_bump(unc_bump_cache);
sLOG << "Also bumping" << tlx::wrap_unprintable(*(begin + i));
bump_vec->push_back(*(begin + i));
// proceed to next item, don't do regular bumping (but only
// if cval == 2, not generally!)
continue;
}
}
bs->SetMeta(bucket, thresh);
all_good = false;
do_bump(bump_cache);
bump_vec->push_back(*(begin + i));
} else {
sLOG << "Insertion succeeded of item"
<< tlx::wrap_unprintable(*(begin + i)) << "in pos" << row
<< "bucket" << bucket << "val" << val << cval << "start"
<< start << "sort" << sortpos << "hash" << std::hex << hash
<< "data" << (uint64_t)rr << std::dec;
bump_cache.emplace_back(row, i);
if constexpr (oneBitThresh) {
// also record in bump cache for uncompressed values
unc_bump_cache.emplace_back(row, i);
}
}
}
// set final threshold
if (thresh == Hasher::NoBumpThresh()) {
bs->SetMeta(last_bucket, thresh);
}
// migrate thresholds to hash table
if constexpr (oneBitThresh) {
hasher.Finalise(num_buckets);
}
LOGC(log) << "\tActual insertion took " << timer.ElapsedNanos(true) / 1e6
<< "ms";
return true;
}
} // namespace ribbon