Revision 97573c6e6f803a234be18e37648c3399537808fd authored by Volodymyr Kysenko on 27 October 2023, 21:21:26 UTC, committed by GitHub on 27 October 2023, 21:21:26 UTC
* Minimal hoist_storage plumbing

* HoistedStorage placeholder IR node

* Basic hoist_storage test

* Fully plumb through the HoistedStorage node

* IRPrinter for HoistedStorage

* Insert hoisted storage at the correct loop level

* Progress

* Formatted

* Move out common code for creating Allocate node

* Format

* Emit Allocate at the HoistedStorage site

* Collect all dependent vars

* Basic test working

* Progress

* Substitute lets into allocation extents instead of lifting stuff

* Infer bounds for the extents dependent on loop variables

* Update tests

* Remove old code

* Remove old code

* Better tests

* More tests

* Validate schedules with hoist_storage

* Error test

* Fix stupid mistake

* More tests

* Remove debug prints

* Better errors

* Add missing handler for inlined functions

* Format

* Comments

* Format

* Add some missing visit handlers

* New line

* Fix comment

* Luckily we only have two build systems

* Adds hoist_storage_root

* Comment for IR node

* Serialization support for HoistedStorage

* Handle hoist_storage for tuples

* Handle multiple realize nodes

* Move assert up

* Better error message

* Better loop bounds

* Format

* Updated error message

* Happy clang-tidy happy me

* An error message when compute is inlined, but store is not inlined

* Only mutate lets which are needed

* Update apps to use hoist_storage

Some very minor performance gains, but mostly in the noise.

Also switched the apps makefiles to emit stmt html by default instead of
stmt, to take advantage of the new and improved stmt html.

* Switch to stack of hoisted storages

* Limit scope of lets for expansion

* Break early

* Skip substitute_in_all_lets

* Re-use expanded min/extents

* WebAssembly JIT does not support custom allocators

* Change debug level to get more info about segfault

* More debugging prints

* Let's try aligned malloc

* Revert "Change debug level to get more info about segfault"

This reverts commit a5a689be8c6ad351674f3ced3bbf542335f91d75.

* Revert "More debugging prints"

This reverts commit bb6b8c1313cbdb9f355df20fd203ee02d485042e.

---------

Co-authored-by: Andrew Adams <andrew.b.adams@gmail.com>
1 parent ed357c2
Raw File
find_inverse.cpp
#include <algorithm>
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Returns a pseudo-random value built from three consecutive rand()
// draws, reduced modulo (max - min) and offset by min. Because rand()
// never returns a negative number, the result lies in [min, max) when
// max > min. max == min divides by zero.
// NOTE(review): the second and third draws are both shifted left by 16;
// one of them may have been meant to use a wider shift. Preserved as is.
int64_t r(int64_t min, int64_t max) {
    const int64_t first = rand();
    const int64_t second = rand();
    const int64_t third = rand();
    const int64_t mixed = first ^ (second << 16) ^ (third << 16);
    const int64_t span = max - min;
    return (mixed % span) + min;
}

// Signed division of a by b. For positive b the quotient is rounded
// toward negative infinity: the non-negative remainder is removed from
// a first, so the final division is exact.
int64_t sdiv(int64_t a, int64_t b) {
    const int64_t raw_rem = a % b;
    const int64_t wrapped_rem = (raw_rem + b) % b;
    const int64_t exact_multiple = a - wrapped_rem;
    return exact_multiple / b;
}

// Signed division of a by b using the language's built-in operator,
// which rounds the quotient toward zero.
int64_t srzdiv(int64_t a, int64_t b) {
    const int64_t quotient = a / b;
    return quotient;
}

// Exhaustively checks whether unsigned division by den can be replaced
// by a plain right shift: returns true iff num / den == num >> sh_post
// for every num in [0, 2^bits - 1].
bool u_method_0(int den, int sh_post, int bits) {
    // 1ULL keeps the shift 64 bits wide even where `long` is 32 bits.
    const int64_t max = (int64_t)((1ULL << bits) - 1);
    for (int64_t num = 0; num <= max; num++) {
        const uint64_t result = (uint64_t)num >> sh_post;
        if ((uint64_t)(num / den) != result) return false;
    }
    return true;
}

// Exhaustively checks the multiply-then-shift scheme for unsigned
// division: returns true iff, for every num in [0, 2^bits - 1],
//   ((num * mul) >> bits) >> sh_post == num / den
// with mul and the intermediate high half both fitting in `bits` bits.
bool u_method_1(int den, int64_t mul, int sh_post, int bits) {
    // 1ULL keeps the shift 64 bits wide even where `long` is 32 bits.
    const uint64_t max = (1ULL << bits) - 1;
    // A negative mul converts to a huge unsigned value and is rejected.
    if ((uint64_t)mul > max) return false;
    for (uint64_t num = 0; num <= max; num++) {
        uint64_t result = num;
        result *= mul;
        result >>= bits;
        if (result > max) return false;
        result >>= sh_post;
        if (num / den != result) return false;
    }
    return true;
}

// Exhaustively checks the multiply / average / shift scheme for
// unsigned division: returns true iff, for every num in [0, 2^bits - 1],
//   ((((num * mul) >> bits) + num) >> 1) >> sh_post == num / den
// with mul and every intermediate value fitting in `bits` bits.
bool u_method_2(int den, int64_t mul, int sh_post, int bits) {
    // 1ULL keeps the shift 64 bits wide even where `long` is 32 bits.
    const uint64_t max = (1ULL << bits) - 1;
    // A negative mul converts to a huge unsigned value and is rejected.
    if ((uint64_t)mul > max) return false;
    for (uint64_t num = 0; num <= max; num++) {
        uint64_t result = num;
        result *= mul;
        result >>= bits;
        if (result > max) return false;
        result = (result + num) >> 1;
        if (result > max) return false;
        result >>= sh_post;
        if (num / den != result) return false;
    }
    return true;
}

// Exhaustively checks the multiply / rounding-average / shift scheme
// for unsigned division: returns true iff, for every num in
// [0, 2^bits - 1],
//   ((((num * mul) >> bits) + num + 1) >> 1) >> sh_post == num / den
// with mul and every intermediate value fitting in `bits` bits.
bool u_method_3(int den, int64_t mul, int sh_post, int bits) {
    // 1ULL keeps the shift 64 bits wide even where `long` is 32 bits.
    const uint64_t max = (1ULL << bits) - 1;
    // A negative mul converts to a huge unsigned value and is rejected.
    if ((uint64_t)mul > max) return false;
    for (uint64_t num = 0; num <= max; num++) {
        uint64_t result = num;
        result *= mul;
        result >>= bits;
        if (result > max) return false;
        result = (result + num + 1) >> 1;
        if (result > max) return false;
        result >>= sh_post;
        if (num / den != result) return false;
    }
    return true;
}

// Exhaustively checks whether signed division by den, as computed by
// sdiv(), can be replaced by a right shift of the signed value: returns
// true iff sdiv(num, den) == num >> sh_post for every num in
// [-2^(bits-1), 2^(bits-1) - 1].
bool s_method_0(int den, int sh_post, int bits) {
    // 1LL keeps the shift 64 bits wide even where `long` is 32 bits.
    const int64_t min = -(1LL << (bits - 1));
    const int64_t max = (1LL << (bits - 1)) - 1;
    for (int64_t num = min; num <= max; num++) {
        int64_t result = num;
        result >>= sh_post;
        if (sdiv(num, den) != result) return false;
    }
    return true;
}

// Exhaustively checks the multiply-and-shift scheme for signed division
// as computed by sdiv(). For every num in [-2^(bits-1), 2^(bits-1) - 1]
// the candidate is
//   xsign  = num >> (bits - 1)          (0, or all ones when num < 0)
//   q0     = (mul * (xsign ^ num)) >> bits
//   result = xsign ^ (q0 >> sh_post)
// Returns true iff the candidate equals sdiv(num, den) for all of them.
bool s_method_1(int den, int64_t mul, int sh_post, int bits) {
    // Built with 1LL: the plain-int form `1 << (bits - 1)` followed by a
    // negation or `- 1` overflows `int` when bits == 32.
    const int64_t min = -(1LL << (bits - 1));
    const int64_t max = (1LL << (bits - 1)) - 1;

    for (int64_t num = min; num <= max; num++) {
        int64_t result = num;
        uint64_t xsign = result >> (bits - 1);
        // xsign ^ result is unsigned, so the product and the shift below
        // are carried out on unsigned 64-bit values.
        uint64_t q0 = (mul * (xsign ^ result)) >> bits;
        result = xsign ^ (q0 >> sh_post);
        if (sdiv(num, den) != result) return false;
    }
    return true;
}

// Exhaustively checks a shift-only scheme for signed division as
// computed by srzdiv(). For every num in [-2^(bits-1), 2^(bits-1) - 1]
// the candidate adds (den - 1) to negative inputs (selected by masking
// with the sign bits) and then shifts right by sh_post. Returns true iff
// the candidate equals srzdiv(num, den) for all of them.
bool srz_method_0(int den, int sh_post, int bits) {
    // 1LL keeps the shift 64 bits wide even where `long` is 32 bits.
    const int64_t min = -(1LL << (bits - 1));
    const int64_t max = (1LL << (bits - 1)) - 1;
    for (int64_t num = min; num <= max; num++) {
        int64_t result = num;
        result += (result >> (bits - 1)) & (den - 1);
        result >>= sh_post;
        if (srzdiv(num, den) != result) return false;
    }
    return true;
}

// Exhaustively checks the multiply-and-shift scheme for signed division
// as computed by srzdiv(). For every num in
// [-2^(bits-1), 2^(bits-1) - 1] the candidate is
//   q0     = (mul * num) >> bits        (signed product and shift)
//   result = (q0 >> sh_post) - (xsign & mask)
// reduced to its low `bits` bits and sign-extended, where xsign is all
// ones for negative num and mask is 2^bits - 1. Prints "Fail" to stdout
// and returns false on the first mismatch; returns true otherwise.
bool srz_method_1(int den, int64_t mul, int sh_post, int bits) {
    // Built with 1LL: the plain-int form `1 << (bits - 1)` followed by a
    // negation or `- 1` overflows `int` when bits == 32.
    const int64_t min = -(1LL << (bits - 1));
    const int64_t max = (1LL << (bits - 1)) - 1;
    const uint64_t mask = (1ULL << bits) - 1;
    const int pad = 64 - bits;

    for (int64_t num = min; num <= max; num++) {
        uint64_t xsign = num >> (bits - 1);
        uint64_t q0 = (mul * num) >> bits;
        uint64_t narrowed = (q0 >> sh_post) - (xsign & mask);
        // Fix-up the sign bits. The left shift is done on the unsigned
        // value because left-shifting a negative signed value is
        // undefined; the right shift is then done on the signed value so
        // that it sign-extends.
        int64_t result = (int64_t)(narrowed << pad);
        result >>= pad;
        if (srzdiv(num, den) != result) {
            printf("Fail\n");
            return false;
        }
    }
    return true;
}

int main(int argc, char **argv) {
    /* This program computes a table to help us do cheap integer
        division by a constant. It is based on the paper "Division by
        Invariant Integers using Multiplication" by Granlund and
        Montgomery.

        It writes IntegerDivisionTable.cpp and IntegerDivisionTable.h
        into the current working directory. For each bit width (8, 16,
        32) it emits unsigned, signed and signed-round-to-zero tables,
        first for compile-time divisors and then for runtime divisors.
        Each row has the form {divisor, method, multiplier, post-shift}.
        In the compile-time tables row 0 holds the entry for divisor 256.

        argc and argv are unused. Returns 0 on success and 1 if an
        output file cannot be opened or closed.
    */

    FILE *c_out = fopen("IntegerDivisionTable.cpp", "w");
    if (c_out == NULL) {
        perror("IntegerDivisionTable.cpp");
        return 1;
    }
    FILE *h_out = fopen("IntegerDivisionTable.h", "w");
    if (h_out == NULL) {
        perror("IntegerDivisionTable.h");
        fclose(c_out);
        return 1;
    }

    fprintf(h_out, "%s",
            "#ifndef HALIDE_INTEGER_DIVISION_TABLE_H\n"
            "#define HALIDE_INTEGER_DIVISION_TABLE_H\n"
            "\n"
            "#include <cstdint>\n"
            "\n"
            "/** \\file\n"
            " * Tables telling us how to do integer division via fixed-point\n"
            " * multiplication for various small constants. This file is \n"
            " * automatically generated by find_inverse.cpp.\n"
            " */\n"
            "namespace Halide {\n"
            "namespace Internal {\n"
            "namespace IntegerDivision {\n");

    fprintf(c_out, "%s",
            "/** \\file\n"
            " * Tables telling us how to do integer division\n"
            " * via fixed-point multiplication for various small\n"
            " * constants. This file is automatically generated\n"
            " * by find_inverse.cpp. There are two sets of tables.\n"
            " * The first set is for compile-time-constant divisors\n"
            " * from 2 to 256. The second is for runtime divisors\n"
            " * from 1 to 255. The second set always uses the most\n"
            " * expensive method, while the compile-time set uses\n"
            " * the cheapest method for the given divisor.\n"
            " */\n"
            "\n"
            "#include \"IntegerDivisionTable.h\"\n"
            "\n"
            "namespace Halide {\n"
            "namespace Internal {\n"
            "namespace IntegerDivision {\n\n");

    // Note on the shifts below: they are written with 1LL rather than 1L
    // because (bits + shift + 1) reaches 40 for 32-bit tables, which
    // exceeds the width of `long` on platforms where it is 32 bits.
    for (int runtime = 0; runtime < 2; runtime++) {
        for (int bits = 8; bits <= 32; bits *= 2) {
            // ---- Unsigned table ----
            printf("Generating table%s_u%d...\n", runtime ? "_runtime" : "", bits);
            if (runtime) {
                fprintf(h_out, "extern const int64_t table_runtime_u%d[256][4];\n", bits);
                fprintf(c_out, "const int64_t table_runtime_u%d[256][4] = {\n", bits);
            } else {
                fprintf(h_out, "extern const int64_t table_u%d[256][4];\n", bits);
                fprintf(c_out, "const int64_t table_u%d[256][4] = {\n", bits);
            }
            for (int d = 0; d < 256; d++) {
                // Rows 0 and 1 of the runtime tables are placeholders.
                // (Original note: runtime division by one is handled by
                // a select.)
                if (runtime && d < 2) {
                    fprintf(c_out, "    {0, 0, 0, 0}, // unused\n");
                    continue;
                }

                int den = d;
                if (den == 0) den = 256;
                if (!runtime) {
                    // Try the methods in the order written, keeping the
                    // first one that verifies.
                    for (int shift = 0; shift < 16; shift++) {
                        if (u_method_0(den, shift, bits)) {
                            fprintf(c_out, "    {%d, 0, 0, %d},\n", den, shift);
                            goto next_unsigned;
                        }
                    }

                    for (int shift = 0; shift < 8; shift++) {
                        int64_t mul = (1LL << (bits + shift)) / den + 1;
                        if (u_method_1(den, mul, shift, bits)) {
                            fprintf(c_out, "    {%d, 1, %lldULL, %d},\n", den, (long long)mul, shift);
                            goto next_unsigned;
                        }
                    }

                    for (int shift = 0; shift < 8; shift++) {
                        int64_t mul = (1LL << (bits + shift + 1)) / den - (1LL << bits);
                        mul &= (1LL << bits) - 1;
                        if (u_method_3(den, mul, shift, bits)) {
                            fprintf(c_out, "    {%d, 3, %lldULL, %d},\n", den, (long long)mul, shift);
                            goto next_unsigned;
                        }
                    }
                }

                {
                    // Index of the highest set bit of (den - 1). den is
                    // at least 2 whenever this point is reached in the
                    // runtime case.
                    int shift = 31 - __builtin_clz(den - 1);
                    int64_t mul = (1LL << (bits + shift + 1)) / den - (1LL << bits) + 1;
                    mul &= (1LL << bits) - 1;
                    if (u_method_2(den, mul, shift, bits)) {
                        fprintf(c_out, "    {%d, 2, %lldULL, %d},\n", den, (long long)mul, shift);
                        goto next_unsigned;
                    }
                }

                fprintf(c_out, "ERROR! No solution found for unsigned %d\n", den);
            next_unsigned:;
            }
            fprintf(c_out, "};\n");

            // ---- Signed table (checked against sdiv) ----
            printf("Generating table%s_s%d...\n", runtime ? "_runtime" : "", bits);
            if (runtime) {
                fprintf(h_out, "extern const int64_t table_runtime_s%d[256][4];\n", bits);
                fprintf(c_out, "const int64_t table_runtime_s%d[256][4] = {\n", bits);
            } else {
                fprintf(h_out, "extern const int64_t table_s%d[256][4];\n", bits);
                fprintf(c_out, "const int64_t table_s%d[256][4] = {\n", bits);
            }
            for (int d = 0; d < 256; d++) {
                if (runtime && d < 2) {
                    fprintf(c_out, "    {0, 0, 0, 0}, // unused\n");
                    continue;
                }
                int den = d;
                if (den == 0) den = 256;
                if (!runtime) {
                    for (int shift = 0; shift < 8; shift++) {
                        if (s_method_0(den, shift, bits)) {
                            fprintf(c_out, "    {%d, 0, 0, %d},\n", den, shift);
                            goto next_signed;
                        }
                    }
                }

                {
                    int shift = 31 - __builtin_clz(den - 1);
                    int64_t mul = (1LL << (shift + bits)) / den + 1;
                    if (s_method_1(den, mul, shift, bits)) {
                        fprintf(c_out, "    {%d, 1, %lldLL, %d},\n", den, (long long)mul, shift);
                        goto next_signed;
                    }
                }
                fprintf(c_out, "ERROR! No solution found for signed %d\n", den);
            next_signed:;
            }
            fprintf(c_out, "};\n");

            // ---- Signed round-to-zero table (checked against srzdiv) ----
            printf("Generating table%s_srz%d...\n", runtime ? "_runtime" : "", bits);
            if (runtime) {
                fprintf(h_out, "extern const int64_t table_runtime_srz%d[256][4];\n", bits);
                fprintf(c_out, "const int64_t table_runtime_srz%d[256][4] = {\n", bits);
            } else {
                fprintf(h_out, "extern const int64_t table_srz%d[256][4];\n", bits);
                fprintf(c_out, "const int64_t table_srz%d[256][4] = {\n", bits);
            }
            for (int d = 0; d < 256; d++) {
                if (runtime && d < 2) {
                    fprintf(c_out, "    {0, 0, 0, 0}, // unused\n");
                    continue;
                }
                int den = d;
                if (den == 0) den = 256;
                if (!runtime) {
                    for (int shift = 0; shift < 8; shift++) {
                        if (srz_method_0(den, shift, bits)) {
                            fprintf(c_out, "    {%d, 0, 0, %d},\n", den, shift);
                            goto next_signedrz;
                        }
                    }
                }

                {
                    int shift = 31 - __builtin_clz(den - 1);
                    int64_t mul = (1LL << (shift + bits)) / den + 1;
                    if (srz_method_1(den, mul, shift, bits)) {
                        fprintf(c_out, "    {%d, 1, %lldLL, %d},\n", den, (long long)mul, shift);
                        goto next_signedrz;
                    }
                }
                fprintf(c_out, "ERROR! No solution found for signed %d\n", den);
            next_signedrz:;
            }
            fprintf(c_out, "};\n");
        }
    }

    fprintf(h_out, "}\n}\n}\n\n#endif\n");
    fprintf(c_out, "}\n}\n}\n");

    // fclose flushes buffered output, so a failure here means the
    // generated file may be incomplete.
    int status = 0;
    if (fclose(h_out) != 0) {
        perror("IntegerDivisionTable.h");
        status = 1;
    }
    if (fclose(c_out) != 0) {
        perror("IntegerDivisionTable.cpp");
        status = 1;
    }

    return status;
}
back to top