// storage.hpp, part of BuRR (https://github.com/lorenzhs/BuRR)
// Copyright (c) Lorenz Hübschle-Schneider
// Copyright (c) Facebook, Inc. and its affiliates.
// All Rights Reserved. This source code is licensed under the Apache 2.0
// License (found in the LICENSE file in the root directory).
#pragma once
#include "config.hpp"
#include <cassert>
#include <memory>
namespace ribbon {
namespace {
template <typename Config>
class MetaStorage {
public:
IMPORT_RIBBON_CONFIG(Config);
static constexpr bool debug = true;
// TODO: if Config::kSparseCoeffs, we don't need to store the least
// significant bits of the threshold because the start positions (and ergo
// thresholds) are aligned to log(#set bits) or something like that
static constexpr unsigned meta_bits = thresh_meta_bits<Config>;
using meta_t = at_least_t<meta_bits>;
static constexpr unsigned meta_t_bits = 8u * sizeof(meta_t);
    // whether meta_bits divides meta_t_bits evenly, i.e., a bucket's
    // metadata never straddles two meta_t words and a single fetch suffices
static constexpr bool div_clean = (meta_t_bits % meta_bits) == 0;
using fetch_t =
std::conditional_t<div_clean, meta_t, at_least_t<2 * meta_t_bits>>;
static constexpr unsigned fetch_bits = 8u * sizeof(fetch_t),
items_per_fetch = fetch_bits / meta_bits;
static constexpr fetch_t extractor_mask = (fetch_t{1} << meta_bits) - 1;
static constexpr Index shift_mask = items_per_fetch - 1;
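    // Worked example of the packing above (hypothetical numbers, assuming
    // two-bit thresholds, i.e. meta_bits = 2 and at_least_t<2> = uint8_t):
    // meta_t_bits = 8, div_clean holds, fetch_t = uint8_t, items_per_fetch
    // = 4, extractor_mask = 0b11, shift_mask = 3. Bucket 5's metadata then
    // sits in meta_[5 / 4] = meta_[1] at bit offset 2 * (5 & 3) = 2. With
    // meta_bits = 3 instead, div_clean would fail and fetch_t would widen
    // to uint16_t so that an item straddling two meta_t words can still be
    // read in one fetch.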
void Prepare(size_t num_slots) {
assert(num_slots >= kCoeffBits);
num_slots_ = num_slots;
Index num_starts = num_slots - kCoeffBits + 1;
num_buckets_ = (num_starts + kBucketSize - 1) / kBucketSize;
        // allocate one extra entry when !div_clean so the two-meta_t fetch
        // in GetMeta/SetMeta never reads past the end of the array, even
        // though the extra bits are never used
size_t size = GetMetaSize() + !div_clean;
sLOGC(Config::log) << "Meta: allocating" << size << "entries of"
<< sizeof(meta_t) << "Bytes each";
meta_ = std::make_unique<meta_t[]>(size);
if constexpr (kThreshMode == ThreshMode::onebit) {
assert(size == (num_buckets_ + 7) / 8);
} else if constexpr (kThreshMode == ThreshMode::twobit) {
assert(size == (num_buckets_ + 3) / 4);
}
}
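    // Worked example (hypothetical numbers): num_slots = 1024 with
    // kCoeffBits = 64 gives num_starts = 961; with kBucketSize = 64,
    // num_buckets_ = (961 + 63) / 64 = 16. For two-bit thresholds,
    // GetMetaSize() = (16 * 2 + 7) / 8 = 4 meta_t entries (div_clean
    // holds, so no extra entry is allocated).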
void Reset() {
meta_.reset();
}
inline void PrefetchMeta(Index bucket) const {
const Index fetch_bucket = div_clean ? bucket / items_per_fetch
: bucket * meta_bits / meta_t_bits;
__builtin_prefetch(meta_.get() + fetch_bucket,
/* rw */ 0, /* locality */ 1);
}
inline meta_t GetMeta(Index bucket) const {
assert(bucket < num_buckets_);
if constexpr (div_clean) {
const meta_t fetch = meta_[bucket / items_per_fetch];
const unsigned shift = meta_bits * (bucket & shift_mask);
assert(shift < fetch_bits);
return (fetch >> shift) & extractor_mask;
} else {
// find the fetch position first
Index start_bit = bucket * meta_bits;
Index fetch_bucket = start_bit / meta_t_bits;
fetch_t fetch;
memcpy(&fetch, meta_.get() + fetch_bucket, sizeof(fetch_t));
// start_bit now indicates which bits of 'fetch' we need
start_bit -= fetch_bucket * meta_t_bits;
return (fetch >> start_bit) & extractor_mask;
}
}
inline void SetMeta(Index bucket, meta_t val) {
assert(bucket < num_buckets_);
assert(val <= extractor_mask);
if constexpr (div_clean) {
const Index pos = bucket / items_per_fetch;
const unsigned shift = meta_bits * (bucket & shift_mask);
meta_[pos] &= ~static_cast<meta_t>(extractor_mask << shift);
meta_[pos] |= (val << shift);
} else {
// find the fetch position first
Index start_bit = bucket * meta_bits;
Index fetch_bucket = start_bit / meta_t_bits;
// start_bit now indicates which bits of 'fetch' we need
start_bit -= fetch_bucket * meta_t_bits;
            // read-modify-write via memcpy (mirroring GetMeta) instead of
            // writing through a reinterpret_cast, which is aliasing-unsafe
            fetch_t fetch;
            memcpy(&fetch, meta_.get() + fetch_bucket, sizeof(fetch_t));
            fetch &= ~(extractor_mask << start_bit);
            fetch |= (static_cast<fetch_t>(val) << start_bit);
            memcpy(meta_.get() + fetch_bucket, &fetch, sizeof(fetch_t));
}
assert(GetMeta(bucket) == val);
}
// invalidates other->meta_!
template <typename Other>
void MoveMetadata(Other* other) {
assert(num_buckets_ == other->num_buckets_);
meta_.swap(other->meta_);
}
// clang-format off
inline Index GetNumSlots() const { return num_slots_; }
inline Index GetNumStarts() const { return num_slots_ - kCoeffBits + 1; }
inline Index GetNumBuckets() const { return num_buckets_; }
// clang-format on
size_t Size() const {
const size_t meta_bytes = GetMetaSize() * sizeof(meta_t);
sLOGC(Config::log) << "\tmeta size: " << num_buckets_ << "*"
<< meta_bits << "bits ->" << meta_bytes << "Bytes";
return meta_bytes + 2 * sizeof(Index) /* don't count num_buckets */;
}
protected:
size_t GetMetaSize() const {
return (num_buckets_ * meta_bits + meta_t_bits - 1) / meta_t_bits;
}
// num_buckets_ is for debugging only & can be recomputed easily
Index num_slots_ = 0, num_buckets_ = 0;
std::unique_ptr<meta_t[]> meta_;
};
} // namespace
template <typename Config>
class BasicStorage : public MetaStorage<Config> {
public:
IMPORT_RIBBON_CONFIG(Config);
using Super = MetaStorage<Config>;
BasicStorage() = default;
explicit BasicStorage(Index num_slots) {
if (num_slots > 0)
Prepare(num_slots);
}
void Prepare(size_t num_slots) {
Super::Prepare(num_slots);
coeffs_ = std::make_unique<CoeffRow[]>(num_slots);
results_ = std::make_unique<ResultRow[]>(num_slots);
}
void Reset() {
coeffs_.reset();
results_.reset();
Super::Reset();
}
inline void PrefetchQuery(Index i) const {
__builtin_prefetch(&results_[i], /* rw */ 0, /* locality */ 1);
}
inline CoeffRow GetCoeffs(Index row) const {
return coeffs_[row];
}
inline void SetCoeffs(Index row, CoeffRow val) {
coeffs_[row] = val;
}
inline ResultRow GetResult(Index row) const {
return results_[row];
}
inline void SetResult(Index row, ResultRow val) {
results_[row] = val;
}
// dummy interface
using State = Index;
inline State PrepareGetResult(Index row) const {
return row;
}
inline ResultRow GetFromState(const State& state) const {
return GetResult(state);
}
inline State AdvanceState(State state) const {
return state + 1;
}
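    // Minimal usage sketch of the State interface (a hypothetical helper,
    // not part of this file; the real query logic lives elsewhere in the
    // library). A ribbon query XORs the result rows of all slots whose
    // coefficient bit is set, starting at the key's start position:
    //
    //   template <typename Config>
    //   typename Config::ResultRow QuerySketch(
    //       const BasicStorage<Config>& storage,
    //       typename Config::Index start, typename Config::CoeffRow cr) {
    //       typename Config::ResultRow result = 0;
    //       auto state = storage.PrepareGetResult(start);
    //       for (unsigned i = 0; i < Config::kCoeffBits; ++i) {
    //           if ((cr >> i) & 1)
    //               result ^= storage.GetFromState(state);
    //           state = storage.AdvanceState(state);
    //       }
    //       return result;
    //   }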
template <typename Iterator, typename Hasher, typename Callback>
void AddRange(Iterator begin, Iterator end, const Hasher& hasher,
Callback bump_callback) {
BandingAddRange(this, hasher, begin, end, bump_callback);
}
size_t Size() const {
return Super::num_slots_ * sizeof(ResultRow) + Super::Size();
}
protected:
std::unique_ptr<CoeffRow[]> coeffs_;
std::unique_ptr<ResultRow[]> results_;
};
// only for backsubstitution, can't be used for AddRange
template <typename Config>
class InterleavedSolutionStorage : public MetaStorage<Config> {
public:
IMPORT_RIBBON_CONFIG(Config);
using Super = MetaStorage<Config>;
InterleavedSolutionStorage() = default;
explicit InterleavedSolutionStorage(Index num_slots) {
if (num_slots > 0)
Prepare(num_slots);
}
void Prepare(size_t num_slots) {
Super::Prepare(num_slots);
size_t size = GetNumSegments() * sizeof(CoeffRow);
data_ = std::make_unique<unsigned char[]>(size);
}
void PrefetchQuery(Index segment_num) const {
__builtin_prefetch(data_.get() + segment_num * sizeof(CoeffRow),
/* rw */ 0, /* locality */ 1);
}
inline CoeffRow GetSegment(Index segment_num) const {
CoeffRow result;
memcpy(&result, data_.get() + segment_num * sizeof(CoeffRow),
sizeof(CoeffRow));
return result;
// return *reinterpret_cast<CoeffRow *>(data_.get() +
// segment_num * sizeof(CoeffRow));
}
inline void SetSegment(Index segment_num, CoeffRow val) {
memcpy(data_.get() + segment_num * sizeof(CoeffRow), &val,
sizeof(CoeffRow));
// *reinterpret_cast<CoeffRow *>(data_.get() +
// segment_num * sizeof(CoeffRow)) = val;
}
// clang-format off
inline Index GetNumBlocks() const { return Super::num_slots_ / kCoeffBits; }
inline Index GetNumSegments() const { return kResultBits * GetNumBlocks(); }
// clang-format on
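    // Layout sketch (a hedged reading of the accessors above): the solution
    // is stored in blocks of kCoeffBits slots with kResultBits segments per
    // block, one CoeffRow per segment. Assuming kCoeffBits = 64 (CoeffRow =
    // uint64_t) and kResultBits = 8, block b owns segments 8 * b .. 8 * b + 7,
    // and segment 8 * b + r packs bit r of the result rows of the block's
    // 64 slots, one bit per slot (column-major within the block).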
    size_t Size() const {
        return GetNumSegments() * sizeof(CoeffRow) + Super::Size();
    }
protected:
std::unique_ptr<unsigned char[]> data_;
};
// For now, two-bit thresholds only
template <typename Config>
class CacheLineStorage {
public:
IMPORT_RIBBON_CONFIG(Config);
static constexpr Index
// bits to store an entire bucket, incl metadata (-> adjust bucket size)
bucketbits = kBucketSize * kResultBits,
// yikes, use fake larger CL if a bucket wouldn't fit
clbits = (bucketbits > 512) ? bucketbits : 512, clsize = clbits / 8u,
buckets_per_cl = clbits / bucketbits,
items_per_row = 8u * sizeof(ResultRow) / kResultBits,
meta_bits_per_bucket = thresh_meta_bits<Config>,
// TODO this could be refined to pack the items, currently we round to
// bytes (also might fail for meta_bits_per_bucket > 8 if buckets_per_cl
// > 1 but when is that ever the case?)
metabytes_per_cl = buckets_per_cl * (meta_bits_per_bucket + 7) / 8,
metarows_per_cl =
(metabytes_per_cl + sizeof(ResultRow) - 1) / sizeof(ResultRow),
items_per_cl = (clbits - 8u * metabytes_per_cl) / kResultBits;
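    // Worked example (hypothetical config): kResultBits = 8 (ResultRow =
    // uint8_t), kBucketSize = 64, two-bit thresholds. Then bucketbits = 512,
    // clbits = 512, clsize = 64, buckets_per_cl = 1, items_per_row = 1,
    // meta_bits_per_bucket = 2, metabytes_per_cl = 1, metarows_per_cl = 1,
    // and items_per_cl = (512 - 8) / 8 = 63: each 64-byte cache line holds
    // one metadata byte followed by 63 result bytes.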
static constexpr bool should_use_compression =
tlx::integer_log2_ceil(kBucketSize) * buckets_per_cl > kResultBits;
// static_assert(should_use_compression == (kThreshMode != ThreshMode::normal),
// "bad config: check kThreshMode");
static constexpr ResultRow maxval = (1ul << kResultBits) - 1;
using meta_t = at_least_t<8 * metabytes_per_cl>;
using meta_item_t = at_least_t<meta_bits_per_bucket>;
// haven't implemented more than 128 meta-bits for lack of a larger type,
// would need to do indexing
static_assert(metabytes_per_cl <= 16, "not implemented");
static_assert(!(buckets_per_cl > 1 && meta_bits_per_bucket > 8),
"not implemented");
static constexpr bool debug = false;
CacheLineStorage() = default;
explicit CacheLineStorage(Index num_slots) {
// loudly warn about suboptimal config choice
if constexpr (should_use_compression !=
(kThreshMode != ThreshMode::normal)) {
sLOG1 << "WARNING: CacheLineStorage disagrees about your choice of "
"threshold compressor:"
<< (should_use_compression
? "SHOULD use compression but isn't"
: "SHOULD NOT use compression but is")
<< "kThreshMode =" << (int)kThreshMode
<< "uncompressed thresholds would need"
<< (tlx::integer_log2_ceil(kBucketSize) * buckets_per_cl)
<< "bits, have" << kResultBits << "bits for thresholds";
}
if (num_slots > 0)
Prepare(num_slots);
}
void Prepare(size_t num_slots) {
num_slots_ = num_slots;
size_t cls = ((num_slots + items_per_cl - 1) / items_per_cl);
size_ = cls * clsize;
sLOGC(Config::log) << "Preparing for" << num_slots << "slots @"
<< items_per_cl
<< "items per cl, kResultBits =" << kResultBits
<< "with" << metabytes_per_cl << "B/CL metainf ->"
<< cls << "CLs," << size_ << "rows,"
<< GetNumBuckets() << "buckets; efficiency:"
<< items_per_cl * 1.0 / (clbits / kResultBits)
<< "bucketbits =" << bucketbits << "clbits =" << clbits
<< "thresh mode" << (int)kThreshMode;
data_ = std::make_unique<unsigned char[]>(size_);
}
inline meta_item_t GetMeta(Index bucket) const {
assert(bucket < GetNumBuckets());
        // clbits / kResultBits = #items that would fit into a cache line if
        // there were no metadata; scaling by it and dividing by items_per_cl
        // (which does account for metadata) maps the bucket onto the cache
        // line that actually holds its items
const Index mapped_bucket =
(bucket * (clbits / kResultBits)) / items_per_cl;
// now account for potentially more than one bucket per CL
const Index cl = mapped_bucket / buckets_per_cl;
meta_t meta_block;
memcpy(&meta_block, data_.get() + cl * clsize, sizeof(meta_t));
if constexpr (kThreshMode == ThreshMode::normal && buckets_per_cl == 1) {
return meta_block;
} else {
const auto shift =
meta_bits_per_bucket * (mapped_bucket & (buckets_per_cl - 1));
const auto mask = (meta_item_t{1} << meta_bits_per_bucket) - 1;
sLOG << "GetMeta" << bucket << "-> cl" << cl << " idx"
<< cl * clsize << "shift" << shift << "mask" << mask;
return static_cast<meta_item_t>(meta_block >> shift) & mask;
}
}
inline void SetMeta(Index bucket, meta_item_t val) const {
assert(bucket < GetNumBuckets());
const Index mapped_bucket =
(bucket * (clbits / kResultBits)) / items_per_cl;
const Index cl = mapped_bucket / buckets_per_cl, idx = cl * clsize;
        // read-modify-write via memcpy (mirroring GetMeta) instead of
        // writing through a reinterpret_cast, which is aliasing-unsafe
        meta_t meta_block;
        memcpy(&meta_block, data_.get() + idx, sizeof(meta_t));
        if constexpr (kThreshMode == ThreshMode::normal && buckets_per_cl == 1) {
            meta_block = val;
        } else {
            const auto shift =
                meta_bits_per_bucket * (mapped_bucket & (buckets_per_cl - 1));
            const auto mask = (meta_item_t{1} << meta_bits_per_bucket) - 1;
            assert(val <= mask);
            sLOG << "SetMeta" << bucket << "val" << (int)val << "-> cl" << cl
                 << "idx" << idx << "shift" << shift << "mask" << mask;
            // first clear, then write (casting the mask before shifting so
            // large shifts don't overflow the promoted int)
            meta_block &= ~(static_cast<meta_t>(mask) << shift);
            meta_block |= (static_cast<meta_t>(val) << shift);
        }
        memcpy(data_.get() + idx, &meta_block, sizeof(meta_t));
assert(GetMeta(bucket) == val);
}
void PrefetchQuery(Index row) const {
const Index cl = row / items_per_cl;
const Index cl_start = cl * clsize;
__builtin_prefetch(data_.get() + cl_start, /* rw */ 0, /* locality */ 1);
}
void PrefetchMeta(Index) const {
// nothing to do, PrefetchQuery already does everything we need
}
inline ResultRow GetResult(Index row) const {
const Index cl = row / items_per_cl;
const Index row_in_cl = row - cl * items_per_cl;
const Index offset = row_in_cl / items_per_row + metarows_per_cl;
const Index data_row = cl * clsize + offset * sizeof(ResultRow);
sLOG << "Get" << row << "cl" << cl << "offset" << offset << "in row"
<< data_row;
const auto shift = (row_in_cl % items_per_row) * kResultBits;
ResultRow result;
memcpy(&result, data_.get() + data_row, sizeof(ResultRow));
return (result >> shift) & maxval;
}
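    // Continuing the hypothetical config from above (one metadata byte plus
    // 63 result bytes per 64-byte cache line, items_per_row = 1): row 100
    // yields cl = 100 / 63 = 1, row_in_cl = 37, offset = 37 + 1 = 38, so the
    // value is read from byte data_row = 64 + 38 = 102 of data_.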
using State =
std::conditional_t<items_per_row == 1, Index, std::pair<Index, Index>>;
inline State PrepareGetResult(Index row) const {
const Index cl = row / items_per_cl;
const Index row_in_cl = row - cl * items_per_cl;
const Index offset = row_in_cl / items_per_row + metarows_per_cl;
const Index data_row = cl * clsize + offset * sizeof(ResultRow);
if constexpr (items_per_row == 1) {
sLOG << "Prep" << row << "cl" << cl << "offset" << offset
<< "in row" << data_row;
return data_row;
} else {
const auto shift = (row_in_cl % items_per_row) * kResultBits;
sLOG << "Prep" << row << "cl" << cl << "offset" << offset
<< "in row" << data_row << "/" << shift;
return std::make_pair(data_row, shift);
}
}
inline ResultRow GetFromState(const State& state) const {
ResultRow result;
if constexpr (items_per_row == 1) {
memcpy(&result, data_.get() + state, sizeof(ResultRow));
return result;
} else {
auto [data_row, shift] = state;
memcpy(&result, data_.get() + data_row, sizeof(ResultRow));
return (result >> shift) & maxval;
}
}
inline State AdvanceState(State state) const {
Index data_row;
if constexpr (items_per_row != 1) {
Index shift;
std::tie(data_row, shift) = state;
if (shift + kResultBits < 8u * sizeof(ResultRow)) {
// same row, just shift more
sLOG << "Adv" << state << "staying in row, new shift"
<< shift + kResultBits;
return std::make_pair(data_row, shift + kResultBits);
}
} else {
data_row = state;
}
Index new_row = data_row + sizeof(ResultRow);
// Skip metadata of next cache line
if (new_row % clsize == 0) {
new_row += metarows_per_cl * sizeof(ResultRow);
}
// new_row += metarows_per_cl * sizeof(ResultRow) * (new_row % clsize == 0);
if constexpr (items_per_row == 1)
return new_row;
else
return std::make_pair(new_row, 0);
}
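    // In the same hypothetical config (clsize = 64, metarows_per_cl = 1,
    // items_per_row = 1): advancing from data_row 63, the last result byte
    // of cache line 0, gives new_row = 64, which is divisible by clsize, so
    // the metadata byte of cache line 1 is skipped and the state lands on
    // byte 65, that line's first result byte.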
inline void SetResult(Index row, ResultRow val) {
const Index cl = row / items_per_cl;
const Index row_in_cl = row - cl * items_per_cl;
const Index offset = row_in_cl / items_per_row + metarows_per_cl;
const Index data_row = cl * clsize + offset * sizeof(ResultRow);
sLOG << "Set" << row << "to" << (int)val << "cl" << cl << "offset"
<< offset << "in row" << data_row;
const auto shift = (row_in_cl % items_per_row) * kResultBits;
        // read-modify-write via memcpy, matching GetResult's access pattern
        ResultRow rr;
        memcpy(&rr, data_.get() + data_row, sizeof(ResultRow));
        rr &= ~static_cast<ResultRow>(maxval << shift);
        rr |= static_cast<ResultRow>(val) << shift;
        memcpy(data_.get() + data_row, &rr, sizeof(ResultRow));
}
// clang-format off
inline Index GetNumSlots() const { return num_slots_; }
inline Index GetNumStarts() const { return num_slots_ - kCoeffBits + 1; }
inline Index GetNumBuckets() const { return (GetNumStarts() + kBucketSize - 1) / kBucketSize; }
// clang-format on
size_t Size() const {
assert(size_ == ((num_slots_ + items_per_cl - 1) / items_per_cl) * clsize);
return size_ + 2 * sizeof(Index);
}
protected:
    size_t size_ = 0;
    Index num_slots_ = 0;
std::unique_ptr<unsigned char[]> data_;
};
} // namespace ribbon