https://github.com/shader-slang/slang
Tip revision: f44da6cc5c0f211c13bd1eb0743d79c7861ea64e authored by Yong He on 09 February 2024, 02:29:32 UTC
Support pointers in SPIRV. (#3561)
Support pointers in SPIRV. (#3561)
Tip revision: f44da6c
slang-dictionary.h
#ifndef SLANG_CORE_DICTIONARY_H
#define SLANG_CORE_DICTIONARY_H
#include "slang-list.h"
#include "slang-linked-list.h"
#include "slang-common.h"
#include "slang-uint-set.h"
#include "slang-exception.h"
#include "slang-math.h"
#include "slang-hash.h"
#include "../../external/unordered_dense/include/ankerl/unordered_dense.h"
#include <initializer_list>
namespace Slang
{
// A key together with the value associated with it.
//
// Both members are public and may be read or modified directly. This is the
// element type stored in the linked list of `OrderedDictionary`.
template<typename TKey, typename TValue>
class KeyValuePair
{
public:
    TKey key;
    TValue value;

    // Default-constructs both members.
    KeyValuePair()
    {}

    // Copies the key and the value.
    KeyValuePair(const TKey& inKey, const TValue& inValue)
        : key(inKey)
        , value(inValue)
    {}

    // Moves the key and the value.
    KeyValuePair(TKey&& inKey, TValue&& inValue)
        : key(_Move(inKey))
        , value(_Move(inValue))
    {}

    // Moves the key, copies the value.
    KeyValuePair(TKey&& inKey, const TValue& inValue)
        : key(_Move(inKey))
        , value(inValue)
    {}

    // Member-wise copy. The members are initialized directly rather than
    // being default-constructed first and then assigned.
    KeyValuePair(const KeyValuePair<TKey, TValue>& that)
        : key(that.key)
        , value(that.value)
    {}

    // Member-wise move. The members are initialized directly rather than
    // being default-constructed first and then assigned.
    KeyValuePair(KeyValuePair<TKey, TValue>&& that)
        : key(_Move(that.key))
        , value(_Move(that.value))
    {}

    // Member-wise move assignment.
    KeyValuePair& operator=(KeyValuePair<TKey, TValue>&& that)
    {
        key = _Move(that.key);
        value = _Move(that.value);
        return *this;
    }

    // Member-wise copy assignment.
    KeyValuePair& operator=(const KeyValuePair<TKey, TValue>& that)
    {
        key = that.key;
        value = that.value;
        return *this;
    }

    // Hash of the pair: the hash of the key combined with the hash of the value.
    HashCode getHashCode() const
    {
        return combineHash(
            Slang::getHashCode(key),
            Slang::getHashCode(value));
    }

    // Two pairs are equal when both their keys and their values compare equal.
    bool operator==(const KeyValuePair<TKey, TValue>& that) const
    {
        return (key == that.key) && (value == that.value);
    }
};
// Convenience factory that deduces the template arguments of `KeyValuePair`
// from its arguments. Both the key and the value are copied into the result.
template<typename TKey, typename TValue>
inline KeyValuePair<TKey, TValue> KVPair(const TKey& k, const TValue& v)
{
    return {k, v};
}
// Uniform accessors for the two pair representations used in this file:
// `std::pair` (members `first`/`second`) and `KeyValuePair` (members
// `key`/`value`). `HashSetBase::Iterator` uses `getKey` so that it does not
// need to know which of the two its underlying dictionary stores.
namespace KeyValueDetail {

    // Pointer to the key of a `std::pair` entry.
    template <typename KEY, typename VALUE>
    SLANG_FORCE_INLINE const KEY* getKey(const std::pair<KEY, VALUE>* in)
    {
        const KEY& keyRef = in->first;
        return &keyRef;
    }

    // Pointer to the key of a `KeyValuePair` entry.
    template <typename KEY, typename VALUE>
    SLANG_FORCE_INLINE const KEY* getKey(const KeyValuePair<KEY, VALUE>* in)
    {
        const KEY& keyRef = in->key;
        return &keyRef;
    }

    // Pointer to the value of a `std::pair` entry.
    template <typename KEY, typename VALUE>
    SLANG_FORCE_INLINE const VALUE* getValue(const std::pair<KEY, VALUE>* in)
    {
        const VALUE& valueRef = in->second;
        return &valueRef;
    }

    // Pointer to the value of a `KeyValuePair` entry.
    template <typename KEY, typename VALUE>
    SLANG_FORCE_INLINE const VALUE* getValue(const KeyValuePair<KEY, VALUE>* in)
    {
        const VALUE& valueRef = in->value;
        return &valueRef;
    }

} // namespace KeyValueDetail
// Load threshold of `OrderedDictionary`: once the ratio of stored entries to
// `m_bucketCountMinusOne` reaches this value the bucket table is doubled.
constexpr float kMaxLoadFactor = 0.7f;
template<typename TKey, typename TValue, typename Hash = Slang::Hash<TKey>, typename KeyEqual = std::equal_to<TKey>>
class Dictionary
{
using InnerMap = ankerl::unordered_dense::map<
TKey,
TValue,
Hash,
KeyEqual>;
using ThisType = Dictionary<TKey, TValue, Hash, KeyEqual>;
InnerMap map;
public:
Dictionary() = default;
Dictionary(const Dictionary&) = default;
Dictionary(Dictionary&&) = default;
ThisType& operator=(const ThisType&) = default;
ThisType& operator=(ThisType&&) = default;
Dictionary(std::initializer_list<typename InnerMap::value_type> inits)
: map(std::move(inits))
{
}
//
// Types
//
using Iterator = typename InnerMap::iterator;
using ConstIterator = typename InnerMap::const_iterator;
using KeyType = TKey;
using ValueType = TValue;
//
// Iterators
//
auto begin() { return map.begin(); }
auto begin() const { return map.begin(); }
auto end() { return map.end(); }
auto end() const { return map.end(); }
//
// Modifiers
//
// Removes all values from the map
void clear() { map.clear(); }
// Erases the value at the specified key if it exists
void remove(const TKey& key) { map.erase(key); }
// Reserves enough space for the specified number of values
void reserve(Index size) { map.reserve(std::size_t(size)); };
// Swap with another map
void swapWith(ThisType& rhs) { std::swap(*this, rhs); }
//
// Query capacity
//
std::size_t getCount() const { return map.size(); }
//
// Lookup
//
// Returns true if the map contains an equivalent key
template<typename K>
bool containsKey(const K& k) const { return map.contains(k); }
// Returns a valid pointer to the requested element, or nullptr if it
// doesn't exist
template<typename K>
const TValue* tryGetValue(const K& key) const
{
auto i = map.find(key);
return i == map.end() ? nullptr : &(i->second);
}
// Returns a valid pointer to the requested element, or nullptr if it
// doesn't exist
template<typename K>
TValue* tryGetValue(const K& key)
{
auto i = map.find(key);
return i == map.end() ? nullptr : std::addressof(i->second);
}
// Returns true and copies the element into 'value' if present.
// Otherwise returns false and value unmodified.
template<typename K>
bool tryGetValue(const K& key, TValue& value) const
{
auto i = map.find(key);
if(i == map.end())
return false;
value = i->second;
return true;
}
// Returns a const reference to the value at the given key. Asserts if
// the value doesn't exist
const TValue& getValue(const TKey& key) const
{
if(const auto x = tryGetValue(key))
return *x;
SLANG_ASSERT_FAILURE("The key does not exist in dictionary.");
}
// Returns a reference to the value at the given key. Asserts if the
// value doesn't exist
TValue& getValue(const TKey& key)
{
if(const auto x = tryGetValue(key))
return *x;
SLANG_ASSERT_FAILURE("The key does not exist in dictionary.");
}
//
// Combined Lookup and Insertion
//
// Tries to insert the given element, if a value was already present at
// the given key then returns a pointer to that element instead.
// Returns nullptr if insertion was successful.
TValue* tryGetValueOrAdd(const typename InnerMap::value_type& kvPair)
{
const auto& [iterator, inserted] = map.insert(kvPair);
return inserted ? nullptr : std::addressof(iterator->second);
}
// Tries to insert the given element, if a value was already present at
// the given key then returns a pointer to that element instead.
// Returns nullptr if insertion was successful.
TValue* tryGetValueOrAdd(typename InnerMap::value_type&& kvPair)
{
const auto& [iterator, inserted] = map.insert(std::move(kvPair));
return inserted ? nullptr : std::addressof(iterator->second);
}
// Tries to insert the given element, if a value was already present at
// the given key then returns a pointer to that element instead.
// Returns nullptr if insertion was successful.
TValue* tryGetValueOrAdd(const TKey& key, const TValue& value) { return tryGetValueOrAdd({key, value}); }
// Inserts the given value if it doesn't exist already
// Return a reference to the (possibly new) value in the map
TValue& getOrAddValue(const TKey& key, const TValue& defaultValue)
{
auto [iterator, inserted] = map.insert({key, defaultValue});
return iterator->second;
}
// Returns a reference to the value at the specified key, default
// initializing it if it doesn't already exist
TValue& operator[]( const TKey& key ) { return map[key]; }
// Returns a reference to the value at the specified key, default
// initializing it if it doesn't already exist
TValue& operator[]( TKey&& key ) { return map[std::move(key)]; }
//
// Insertion
//
// Returns true if the value was inserted, returns false if the map
// already has a value associated with this key
bool addIfNotExists(typename InnerMap::value_type&& kvPair) { return !tryGetValueOrAdd(std::move(kvPair)); }
// Returns true if the value was inserted, returns false if the map
// already has a value associated with this key
bool addIfNotExists(const typename InnerMap::value_type& kvPair) { return !tryGetValueOrAdd(kvPair); }
// Returns true if the value was inserted, returns false if the map
// already has a value associated with this key
bool addIfNotExists(const TKey& k, const TValue& v) { return addIfNotExists({k, v}); }
// Returns true if the value was inserted, returns false if the map
// already has a value associated with this key
bool addIfNotExists(TKey&& k, TValue&& v) { return addIfNotExists({std::move(k), std::move(v)}); }
// Asserts if the key already exists in the dictionary
void add(typename InnerMap::value_type&& kvPair)
{
if (!addIfNotExists(std::move(kvPair)))
SLANG_ASSERT_FAILURE("The key already exists in Dictionary.");
}
// Asserts if the key already exists in the dictionary
void add(const typename InnerMap::value_type& kvPair)
{
if (!addIfNotExists(kvPair))
SLANG_ASSERT_FAILURE("The key already exists in Dictionary.");
}
// Asserts if the key already exists in the dictionary
void add(const TKey& key, const TValue& value) { add({key, value}); }
// Asserts if the key already exists in the dictionary
void add(TKey&& key, TValue&& value) { add({std::move(key), std::move(value)}); }
// Inserts into the dictionary or assigns if the key already exists
void set(const TKey& key, const TValue& value) { map.insert_or_assign(key, value); }
};
/* Empty placeholder used as the mapped type when a dictionary is only needed
 * for its keys, as the hash set classes in this file do.
 * We may want to rename this, as strictly speaking _Caps names are reserved. */
class _DummyClass
{
};
// Set of `T` built on top of a dictionary type.
//
// The elements are stored as the keys of `DictionaryType`; the mapped value is
// always an empty `_DummyClass`. `DictionaryType` must provide
// `addIfNotExists`, `remove`, `containsKey`, `clear`, `getCount`,
// `begin`/`end` and a `ConstIterator` type.
template<typename T, typename DictionaryType>
class HashSetBase
{
protected:
    // Dictionary whose keys are the elements of the set.
    DictionaryType dict;

private:
    // Terminates the recursion of the variadic `init` below. Without this
    // overload the variadic constructor could not be instantiated, because
    // the last recursion step calls `init()` with an empty pack.
    void init()
    {}

    // Adds `v`, then recurses on the remaining arguments.
    template<typename... Args>
    void init(const T& v, Args... args)
    {
        add(v);
        init(args...);
    }

public:
    // Creates an empty set.
    HashSetBase()
    {}

    // Creates a set holding the given elements; duplicates are ignored.
    template<typename Arg, typename... Args>
    HashSetBase(Arg arg, Args... args)
    {
        init(arg, args...);
    }

    // Copies the elements of `set`.
    HashSetBase(const HashSetBase& set)
        : dict(set.dict)
    {}

    // Takes over the elements of `set`.
    HashSetBase(HashSetBase&& set)
        : dict(_Move(set.dict))
    {}

    HashSetBase& operator=(const HashSetBase& set)
    {
        dict = set.dict;
        return *this;
    }

    HashSetBase& operator=(HashSetBase&& set)
    {
        dict = _Move(set.dict);
        return *this;
    }

public:
    // Read-only forward iterator over the elements of the set. It wraps the
    // const iterator of the underlying dictionary and exposes only the key
    // of each entry.
    class Iterator
    {
    private:
        typename DictionaryType::ConstIterator iter;

    public:
        Iterator() = default;

        // Wraps an iterator of the underlying dictionary.
        Iterator(const typename DictionaryType::ConstIterator& _iter)
            : iter(_iter)
        {}

        const T& operator*() const
        {
            return *KeyValueDetail::getKey(std::addressof(*iter));
        }

        const T* operator->() const
        {
            return KeyValueDetail::getKey(std::addressof(*iter));
        }

        Iterator& operator++()
        {
            ++iter;
            return *this;
        }

        Iterator operator++(int)
        {
            Iterator rs = *this;
            operator++();
            return rs;
        }

        bool operator!=(const Iterator& that) const
        {
            return iter != that.iter;
        }

        bool operator==(const Iterator& that) const
        {
            return iter == that.iter;
        }
    };

    Iterator begin() const
    {
        return Iterator(dict.begin());
    }

    Iterator end() const
    {
        return Iterator(dict.end());
    }

public:
    // Number of elements in the set.
    auto getCount() const
    {
        return dict.getCount();
    }

    // Removes every element.
    void clear()
    {
        dict.clear();
    }

    // Adds a copy of `obj`. Returns the result of the dictionary's
    // `addIfNotExists`, i.e. whether the element was newly inserted.
    bool add(const T& obj)
    {
        return dict.addIfNotExists(obj, _DummyClass());
    }

    // Adds `obj` by moving it. Returns the result of the dictionary's
    // `addIfNotExists`, i.e. whether the element was newly inserted.
    bool add(T&& obj)
    {
        return dict.addIfNotExists(_Move(obj), _DummyClass());
    }

    // Removes `obj` from the set if it is present.
    void remove(const T& obj)
    {
        dict.remove(obj);
    }

    // True when `obj` is an element of the set.
    bool contains(const T& obj) const
    {
        return dict.containsKey(obj);
    }
};
// Unordered set of `T`: a `HashSetBase` whose storage is a
// `Dictionary<T, _DummyClass>`.
template <typename T>
class HashSet : public HashSetBase<T, Dictionary<T, _DummyClass>>
{
};
// Hash map that iterates its entries in the order of the linked list
// `m_kvPairs`, to which new entries are always appended.
//
// Layout:
//  - `m_kvPairs` owns the entries (one list node per entry).
//  - `m_hashMap` is an open-addressing table of `m_bucketCountMinusOne + 1`
//    pointers into that list.
//  - `m_marks` holds two flags per bucket `pos`: flag `2 * pos` is set once
//    the bucket has been used, flag `2 * pos + 1` is set while the bucket is a
//    tombstone left behind by `remove`.
template <typename TKey, typename TValue>
class OrderedDictionary
{
    friend class Iterator;
    friend class ItemProxy;

private:
    // Distance between two successive probes. The step is constant, so the
    // table uses linear probing.
    inline int getProbeOffset(int /*probeIdx*/) const
    {
        return 1;
    }

private:
    // Bucket count minus one, or -1 while no table has been allocated. Used as
    // the wrap-around mask when probing.
    int m_bucketCountMinusOne = -1;
    // Number of entries currently stored.
    int m_count = 0;
    // Per-bucket "used" / "deleted" flags (see the class comment).
    UIntSet m_marks;
    // The entries themselves, in iteration order.
    LinkedList<KeyValuePair<TKey, TValue>> m_kvPairs;
    // Bucket table; only buckets that are used and not deleted hold a valid
    // node pointer.
    LinkedNode<KeyValuePair<TKey, TValue>>** m_hashMap = nullptr;

    // Frees the bucket table and all entries. The counters and marks are left
    // for the caller to reset.
    void deallocateAll()
    {
        delete[] m_hashMap;
        m_hashMap = nullptr;
        m_kvPairs.clear();
    }

    inline bool isDeleted(int pos) const { return m_marks.contains((pos << 1) + 1); }
    inline bool isEmpty(int pos) const { return !m_marks.contains((pos << 1)); }

    inline void setDeleted(int pos, bool val)
    {
        if (val)
            m_marks.add((pos << 1) + 1);
        else
            m_marks.remove((pos << 1) + 1);
    }

    inline void setEmpty(int pos, bool val)
    {
        if (val)
            m_marks.remove((pos << 1));
        else
            m_marks.add((pos << 1));
    }

    // Result of a table probe. `objectPosition` is the bucket holding the key
    // (or -1 when the key is absent); `insertionPosition` is the bucket a new
    // entry for the key should go to (or -1 when the key was found).
    struct FindPositionResult
    {
        int objectPosition;
        int insertionPosition;

        FindPositionResult()
            : objectPosition(-1)
            , insertionPosition(-1)
        {}

        FindPositionResult(int objPos, int insertPos)
            : objectPosition(objPos)
            , insertionPosition(insertPos)
        {}
    };

    // Bucket at which probing for `key` starts. The scrambled hash is reduced
    // modulo `m_bucketCountMinusOne` (not the bucket count), so the last
    // bucket is never a starting position; it is only reached by probing.
    template <typename T> inline int getHashPos(T& key) const
    {
        const unsigned int hash = (unsigned int)getHashCode(key);
        return ((unsigned int)(hash * 2654435761)) % m_bucketCountMinusOne;
    }

    // Probes the table for `key`.
    //
    // Returns (bucket, -1) when the key is stored, (-1, bucket) when it is
    // absent and `bucket` is where it should be inserted (the first tombstone
    // met, otherwise the first unused bucket), and (-1, -1) when no table has
    // been allocated yet.
    template <typename T> FindPositionResult findPosition(const T& key) const
    {
        // Without a table there is nothing to probe: the key is absent and
        // there is no slot to offer. Falling through would skip the loop and
        // wrongly report a full table.
        if (m_bucketCountMinusOne == -1)
            return FindPositionResult();

        int hashPos = getHashPos(const_cast<T&>(key));
        int insertPos = -1;
        int numProbes = 0;
        while (numProbes <= m_bucketCountMinusOne)
        {
            if (isEmpty(hashPos))
            {
                // A never-used bucket ends the probe sequence: the key is absent.
                if (insertPos == -1)
                    return FindPositionResult(-1, hashPos);
                else
                    return FindPositionResult(-1, insertPos);
            }
            else if (isDeleted(hashPos))
            {
                // Remember the first tombstone so it can be reused for insertion.
                if (insertPos == -1)
                    insertPos = hashPos;
            }
            else if (m_hashMap[hashPos]->value.key == key)
            {
                return FindPositionResult(hashPos, -1);
            }
            numProbes++;
            hashPos = (hashPos + getProbeOffset(numProbes)) & m_bucketCountMinusOne;
        }
        if (insertPos != -1)
            return FindPositionResult(-1, insertPos);
        SLANG_ASSERT_FAILURE("Hash map is full. This indicates an error in Key::Equal or Key::GetHashCode.");
    }

    // Appends `kvPair` to the entry list and records its node in bucket `pos`.
    // Returns a reference to the stored value. Does not touch `m_count`.
    TValue& _insert(KeyValuePair<TKey, TValue>&& kvPair, int pos)
    {
        auto node = m_kvPairs.addLast();
        node->value = _Move(kvPair);
        m_hashMap[pos] = node;
        setEmpty(pos, false);
        setDeleted(pos, false);
        return node->value.value;
    }

    // Allocates the table on first use (128 buckets) and doubles it once the
    // ratio of entries to `m_bucketCountMinusOne` reaches `kMaxLoadFactor`.
    // Existing entries are moved into the new table in iteration order.
    void maybeRehash()
    {
        if (m_bucketCountMinusOne == -1 || m_count / (float)m_bucketCountMinusOne >= kMaxLoadFactor)
        {
            int newSize = (m_bucketCountMinusOne + 1) * 2;
            if (newSize == 0)
            {
                newSize = 128;
            }
            OrderedDictionary<TKey, TValue> newDict;
            newDict.m_bucketCountMinusOne = newSize - 1;
            newDict.m_hashMap = new LinkedNode<KeyValuePair<TKey, TValue>>*[newSize];
            newDict.m_marks.resizeAndClear(newSize * 2);
            if (m_hashMap)
            {
                for (auto& kvPair : *this)
                {
                    newDict.add(_Move(kvPair));
                }
            }
            *this = _Move(newDict);
        }
    }

    // Inserts `kvPair` unless its key is already present. Returns true when
    // the entry was inserted.
    bool addIfNotExists(KeyValuePair<TKey, TValue>&& kvPair)
    {
        maybeRehash();
        auto pos = findPosition(kvPair.key);
        if (pos.objectPosition != -1)
            return false;
        else if (pos.insertionPosition != -1)
        {
            m_count++;
            _insert(_Move(kvPair), pos.insertionPosition);
            return true;
        }
        else
            SLANG_ASSERT_FAILURE("Inconsistent find result returned. This is a bug in Dictionary implementation.");
    }

    // Inserts `kvPair`; asserts when its key is already present.
    void add(KeyValuePair<TKey, TValue>&& kvPair)
    {
        if (!addIfNotExists(_Move(kvPair)))
            SLANG_ASSERT_FAILURE("The key already exists in Dictionary.");
    }

    // Inserts or replaces the entry for `kvPair.key` and returns a reference
    // to the stored value. When the key already exists its old node is
    // removed and a new node is appended to the entry list.
    TValue& set(KeyValuePair<TKey, TValue>&& kvPair)
    {
        maybeRehash();
        auto pos = findPosition(kvPair.key);
        if (pos.objectPosition != -1)
        {
            m_hashMap[pos.objectPosition]->removeAndDelete();
            return _insert(_Move(kvPair), pos.objectPosition);
        }
        else if (pos.insertionPosition != -1)
        {
            m_count++;
            return _insert(_Move(kvPair), pos.insertionPosition);
        }
        else
            SLANG_ASSERT_FAILURE("Inconsistent find result returned. This is a bug in Dictionary implementation.");
    }

public:
    using Iterator = typename LinkedList<KeyValuePair<TKey, TValue>>::Iterator;
    using ConstIterator = typename LinkedList<KeyValuePair<TKey, TValue>>::ConstIterator;

    // Iteration visits the entries in the order of the entry list.
    Iterator begin() { return m_kvPairs.begin(); }
    Iterator end() { return m_kvPairs.end(); }
    ConstIterator begin() const { return m_kvPairs.begin(); }
    ConstIterator end() const { return m_kvPairs.end(); }

public:
    // Inserts a copy of `key` -> `value`; asserts when the key already exists.
    void add(const TKey& key, const TValue& value)
    {
        add(KeyValuePair<TKey, TValue>(key, value));
    }

    // Inserts `key` -> `value` by moving both; asserts when the key already
    // exists.
    void add(TKey&& key, TValue&& value)
    {
        add(KeyValuePair<TKey, TValue>(_Move(key), _Move(value)));
    }

    // Inserts a copy of `key` -> `value` unless the key already exists.
    // Returns true when the entry was inserted.
    bool addIfNotExists(const TKey& key, const TValue& value)
    {
        return addIfNotExists(KeyValuePair<TKey, TValue>(key, value));
    }

    // Inserts `key` -> `value` by moving both unless the key already exists.
    // Returns true when the entry was inserted.
    bool addIfNotExists(TKey&& key, TValue&& value)
    {
        return addIfNotExists(KeyValuePair<TKey, TValue>(_Move(key), _Move(value)));
    }

    // Removes the entry stored under `key`, if any. The bucket is marked as
    // deleted so that probe sequences passing through it stay intact.
    void remove(const TKey& key)
    {
        if (m_count > 0)
        {
            auto pos = findPosition(key);
            if (pos.objectPosition != -1)
            {
                m_kvPairs.removeAndDelete(m_hashMap[pos.objectPosition]);
                m_hashMap[pos.objectPosition] = nullptr;
                setDeleted(pos.objectPosition, true);
                m_count--;
            }
        }
    }

    // Removes all entries and clears the bucket marks. The bucket table
    // itself stays allocated.
    void clear()
    {
        m_count = 0;
        m_kvPairs.clear();
        m_marks.clear();
    }

    // True when an entry with the given key is stored.
    template <typename T> bool containsKey(const T& key) const
    {
        if (m_bucketCountMinusOne == -1)
            return false;
        auto pos = findPosition(key);
        return pos.objectPosition != -1;
    }

    // Pointer to the value stored under `key`, or nullptr when the key is
    // absent.
    template <typename T> TValue* tryGetValue(const T& key) const
    {
        if (m_bucketCountMinusOne == -1)
            return nullptr;
        auto pos = findPosition(key);
        if (pos.objectPosition != -1)
        {
            return &(m_hashMap[pos.objectPosition]->value.value);
        }
        return nullptr;
    }

    // Copies the value stored under `key` into `value` and returns true, or
    // returns false and leaves `value` untouched when the key is absent.
    template <typename T> bool tryGetValue(const T& key, TValue& value) const
    {
        if (m_bucketCountMinusOne == -1)
            return false;
        auto pos = findPosition(key);
        if (pos.objectPosition != -1)
        {
            value = m_hashMap[pos.objectPosition]->value.value;
            return true;
        }
        return false;
    }

    // Object returned by `operator[]`. Reading it (conversion to `TValue&`,
    // `operator()`, `getValue`) looks the key up and asserts when it is
    // absent; assigning to it inserts or replaces the entry.
    class ItemProxy
    {
    private:
        const OrderedDictionary<TKey, TValue>* dict;
        TKey key;

    public:
        ItemProxy(const TKey& _key, const OrderedDictionary<TKey, TValue>* _dict)
            : dict(_dict)
            , key(_key)
        {}

        ItemProxy(TKey&& _key, const OrderedDictionary<TKey, TValue>* _dict)
            : dict(_dict)
            , key(_Move(_key))
        {}

        // Reference to the value stored under the proxy's key. Asserts when
        // the key is absent.
        TValue& getValue() const
        {
            auto pos = dict->findPosition(key);
            if (pos.objectPosition != -1)
            {
                return dict->m_hashMap[pos.objectPosition]->value.value;
            }
            else
            {
                SLANG_ASSERT_FAILURE("The key does not exists in dictionary.");
            }
        }

        inline TValue& operator()() const { return getValue(); }
        operator TValue&() const { return getValue(); }

        // Stores a copy of `val` under the proxy's key. The key is moved out
        // of the proxy, so the proxy must not be used afterwards.
        TValue& operator=(const TValue& val)
        {
            return const_cast<OrderedDictionary<TKey, TValue>*>(dict)
                ->set(KeyValuePair<TKey, TValue>(_Move(key), val));
        }

        // Stores `val` under the proxy's key by moving it. The key is moved
        // out of the proxy, so the proxy must not be used afterwards.
        TValue& operator=(TValue&& val)
        {
            return const_cast<OrderedDictionary<TKey, TValue>*>(dict)
                ->set(KeyValuePair<TKey, TValue>(_Move(key), _Move(val)));
        }
    };

    ItemProxy operator[](const TKey& key) const { return ItemProxy(key, this); }
    ItemProxy operator[](TKey&& key) const { return ItemProxy(_Move(key), this); }

    // Number of entries currently stored.
    int getCount() const { return m_count; }

    // First / last entry of the entry list.
    KeyValuePair<TKey, TValue>& getFirst() const { return m_kvPairs.getFirst(); }
    KeyValuePair<TKey, TValue>& getLast() const { return m_kvPairs.getLast(); }

private:
    // Terminates the recursion of the variadic `init` below.
    void init()
    {}

    // Adds `kvPair` (asserting on a duplicate key), then recurses on the
    // remaining arguments.
    template <typename... Args>
    void init(const KeyValuePair<TKey, TValue>& kvPair, Args... args)
    {
        add(kvPair.key, kvPair.value);
        init(args...);
    }

public:
    // Creates an empty dictionary; no table is allocated until the first
    // insertion.
    OrderedDictionary()
    {}

    // Creates a dictionary holding the given key/value pairs, inserted in
    // argument order.
    template <typename Arg, typename... Args> OrderedDictionary(Arg arg, Args... args)
    {
        init(arg, args...);
    }

    OrderedDictionary(const OrderedDictionary<TKey, TValue>& other)
    {
        *this = other;
    }

    OrderedDictionary(OrderedDictionary<TKey, TValue>&& other)
    {
        *this = (_Move(other));
    }

    // Replaces the contents with copies of the entries of `other`, added in
    // its iteration order.
    OrderedDictionary<TKey, TValue>&
        operator=(const OrderedDictionary<TKey, TValue>& other)
    {
        if (this == &other)
            return *this;
        clear();
        for (auto& item : other)
            add(item.key, item.value);
        return *this;
    }

    // Takes over the table and entries of `other`, leaving `other` without a
    // table and with a count of zero.
    OrderedDictionary<TKey, TValue>&
        operator=(OrderedDictionary<TKey, TValue>&& other)
    {
        if (this == &other)
            return *this;
        deallocateAll();
        m_bucketCountMinusOne = other.m_bucketCountMinusOne;
        m_count = other.m_count;
        m_hashMap = other.m_hashMap;
        m_marks = _Move(other.m_marks);
        other.m_hashMap = nullptr;
        other.m_count = 0;
        other.m_bucketCountMinusOne = -1;
        m_kvPairs = _Move(other.m_kvPairs);
        return *this;
    }

    ~OrderedDictionary() { deallocateAll(); }
};
// Set of `T` that iterates its elements in the order kept by the underlying
// `OrderedDictionary`.
template <typename T>
class OrderedHashSet : public HashSetBase<T, OrderedDictionary<T, _DummyClass>>
{
public:
    // Reference to the key of the dictionary's last entry.
    T& getLast()
    {
        auto& lastEntry = this->dict.getLast();
        return lastEntry.key;
    }

    // Removes the element returned by `getLast()`.
    void removeLast()
    {
        T& last = getLast();
        this->remove(last);
    }
};
}
#endif