https://github.com/halide/Halide
Raw File
Tip revision: 72fa5cce73c00a5e640a620e6a0d3778fb1e6460 authored by Volodymyr Kysenko on 27 February 2023, 23:56:32 UTC
Fix formatting
Tip revision: 72fa5cc
gpu_context_common.h
#include "printer.h"
#include "scoped_mutex_lock.h"

namespace Halide {
namespace Internal {

template<typename ContextT, typename ModuleStateT>
class GPUCompilationCache {
    struct CachedCompilation {
        ContextT context{};
        ModuleStateT module_state{};
        uint32_t kernel_id{};
        uint32_t use_count{0};

        CachedCompilation(ContextT context, ModuleStateT module_state,
                          uint32_t kernel_id, uint32_t use_count)
            : context(context), module_state(module_state),
              kernel_id(kernel_id), use_count(use_count) {
        }
    };

    halide_mutex mutex;

    static constexpr float kLoadFactor{.5f};
    static constexpr int kInitialTableBits{7};
    int log2_compilations_size{0};  // number of bits in index into compilations table.
    CachedCompilation *compilations{nullptr};
    int count{0};

    static constexpr uint32_t kInvalidId{0};
    static constexpr uint32_t kDeletedId{1};

    uint32_t unique_id{2};  // zero is an invalid id

public:
    static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits) {
        uintptr_t addr = (uintptr_t)context + id;
        // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
        // in hexadecimal.
        if (sizeof(uintptr_t) >= 8) {
            return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
        } else {
            return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
        }
    }

    HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
        if (log2_compilations_size == 0) {
            if (!resize_table(kInitialTableBits)) {
                return false;
            }
        }
        if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
            if (!resize_table(log2_compilations_size + 1)) {
                return false;
            }
        }
        count += 1;
        uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
            if (compilations[effective_index].kernel_id <= kDeletedId) {
                compilations[effective_index] = entry;
                return true;
            }
        }
        // This is a logic error that should never occur. It means the table is
        // full, but it should have been resized.
        halide_debug_assert(nullptr, false);
        return false;
    }

    HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uint32_t id,
                                              ModuleStateT *&module_state, int increment) {
        if (log2_compilations_size == 0) {
            return false;
        }
        uintptr_t index = kernel_hash(context, id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);

            if (compilations[effective_index].kernel_id == kInvalidId) {
                return false;
            }
            if (compilations[effective_index].context == context &&
                compilations[effective_index].kernel_id == id) {
                module_state = &compilations[effective_index].module_state;
                if (increment != 0) {
                    compilations[effective_index].use_count += increment;
                }
                return true;
            }
        }
        return false;
    }

    HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
        ScopedMutexLock lock_guard(&mutex);
        uint32_t id = (uint32_t)(uintptr_t)state_ptr;
        ModuleStateT *mod_ptr;
        if (find_internal(context, id, mod_ptr, 0)) {
            module_state = *mod_ptr;
            return true;
        }
        return false;
    }

    HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
        if (size_bits != log2_compilations_size) {
            int new_size = (1 << size_bits);
            int old_size = (1 << log2_compilations_size);
            CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
            if (new_table == nullptr) {
                // signal error.
                return false;
            }
            memset(new_table, 0, new_size * sizeof(CachedCompilation));
            CachedCompilation *old_table = compilations;
            compilations = new_table;
            log2_compilations_size = size_bits;

            if (count > 0) {  // Mainly to catch empty initial table case
                for (int32_t i = 0; i < old_size; i++) {
                    if (old_table[i].kernel_id != kInvalidId &&
                        old_table[i].kernel_id != kDeletedId) {
                        bool result = insert(old_table[i]);
                        halide_debug_assert(nullptr, result);  // Resizing the table while resizing the table is a logic error.
                        (void)result;
                    }
                }
            }
            free(old_table);
        }
        return true;
    }

    template<typename FreeModuleT>
    void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
        if (count == 0) {
            return;
        }

        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            if (compilations[i].kernel_id > kInvalidId &&
                (all || (compilations[i].context == context)) &&
                compilations[i].use_count == 0) {
                debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
                                    << " id " << compilations[i].kernel_id
                                    << " context " << compilations[i].context << "\n";
                f(compilations[i].module_state);
                compilations[i].module_state = nullptr;
                compilations[i].kernel_id = kDeletedId;
                count--;
            }
        }
    }

    template<typename FreeModuleT>
    void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context(user_context, false, context, f);
    }

    template<typename FreeModuleT>
    void release_all(void *user_context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context(user_context, true, nullptr, f);
        // Some items may have been in use, so can't free.
        if (count == 0) {
            free(compilations);
            compilations = nullptr;
            log2_compilations_size = 0;
        }
    }

    template<typename CompileModuleT, typename... Args>
    HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr,
                                                   ContextT context, ModuleStateT &result,
                                                   CompileModuleT f,
                                                   Args... args) {
        ScopedMutexLock lock_guard(&mutex);

        uint32_t *id_ptr = (uint32_t *)state_ptr;
        if (*id_ptr == 0) {
            *id_ptr = unique_id++;
        }

        ModuleStateT *mod;
        if (find_internal(context, *id_ptr, mod, 1)) {
            result = *mod;
            return true;
        }

        // TODO(zvookin): figure out the calling signature here...
        ModuleStateT compiled_module = f(args...);
        debug(user_context) << "Caching compiled kernel: " << compiled_module
                            << " id " << *id_ptr << " context " << context << "\n";
        if (compiled_module == nullptr) {
            return false;
        }

        if (!insert({context, compiled_module, *id_ptr, 1})) {
            return false;
        }
        result = compiled_module;

        return true;
    }

    void release_hold(void *user_context, ContextT context, void *state_ptr) {
        ModuleStateT *mod;
        uint32_t id = (uint32_t)(uintptr_t)state_ptr;
        bool result = find_internal(context, id, mod, -1);
        halide_debug_assert(user_context, result);  // Value must be in cache to be released
        (void)result;
    }
};

}  // namespace Internal
}  // namespace Halide
back to top