https://github.com/halide/Halide
Raw File
Tip revision: c0cffcf217b8a6681d2f691443545d311d3cc5c2 authored by Andrew Adams on 07 December 2017, 22:13:44 UTC
Reset branch onto new history
Tip revision: c0cffcf
CodeGen_C.cpp
#include <iostream>
#include <limits>

#include "CodeGen_C.h"
#include "CodeGen_Internal.h"
#include "Substitute.h"
#include "IROperator.h"
#include "Param.h"
#include "Var.h"
#include "Lerp.h"
#include "Simplify.h"

namespace Halide {
namespace Internal {

using std::ostream;
using std::endl;
using std::string;
using std::vector;
using std::ostringstream;
using std::map;

namespace {
// Emitted into every generated file (header and source). Defines
// HALIDE_ATTRIBUTE_ALIGN (MSVC __declspec vs. GCC-style __attribute__) if the
// host has not, then the legacy buffer_t struct, guarded by BUFFER_T_DEFINED
// so it is defined at most once per translation unit.
const string buffer_t_definition =
    "#ifndef HALIDE_ATTRIBUTE_ALIGN\n"
    "  #ifdef _MSC_VER\n"
    "    #define HALIDE_ATTRIBUTE_ALIGN(x) __declspec(align(x))\n"
    "  #else\n"
    "    #define HALIDE_ATTRIBUTE_ALIGN(x) __attribute__((aligned(x)))\n"
    "  #endif\n"
    "#endif\n"
    "#ifndef BUFFER_T_DEFINED\n"
    "#define BUFFER_T_DEFINED\n"
    "#include <stdbool.h>\n"
    "#include <stdint.h>\n"
    "typedef struct buffer_t {\n"
    "    uint64_t dev;\n"
    "    uint8_t* host;\n"
    "    int32_t extent[4];\n"
    "    int32_t stride[4];\n"
    "    int32_t min[4];\n"
    "    int32_t elem_size;\n"
    "    HALIDE_ATTRIBUTE_ALIGN(1) bool host_dirty;\n"
    "    HALIDE_ATTRIBUTE_ALIGN(1) bool dev_dirty;\n"
    "    HALIDE_ATTRIBUTE_ALIGN(1) uint8_t _padding[10 - sizeof(void *)];\n"
    "} buffer_t;\n"
    "#endif\n";

// Standard includes placed at the top of generated source files only
// (the CodeGen_C constructor skips them for headers).
const string headers =
    "#include <iostream>\n"
    "#include <math.h>\n"
    "#include <float.h>\n"
    "#include <assert.h>\n"
    "#include <string.h>\n"
    "#include <stdio.h>\n"
    "#include <stdint.h>\n";

// Runtime declarations and inline helpers emitted into generated source
// files only (not headers). ExternCallPrototypes scans this text for every
// identifier that precedes a '(' and treats those names as already declared,
// so no prototypes are emitted for them.
const string globals =
    // Halide runtime entry points the generated code may call.
    "extern \"C\" {\n"
    "void *halide_malloc(void *ctx, size_t);\n"
    "void halide_free(void *ctx, void *ptr);\n"
    "void *halide_print(void *ctx, const void *str);\n"
    "void *halide_error(void *ctx, const void *str);\n"
    "int halide_debug_to_file(void *ctx, const char *filename, int, struct buffer_t *buf);\n"
    "int halide_start_clock(void *ctx);\n"
    "int64_t halide_current_time_ns(void *ctx);\n"
    "void halide_profiler_pipeline_end(void *, void *);\n"
    "}\n"
    "\n"

    // TODO: this next chunk is copy-pasted from posix_math.cpp. A
    // better solution for the C runtime would be nice.
    // NOTE(review): asinh_*/acosh_*/atanh_* are only defined in the
    // non-_WIN32 branch, so generated code that calls them will not compile
    // on _WIN32 -- confirm whether that is intended.
    "#ifdef _WIN32\n"
    "float roundf(float);\n"
    "double round(double);\n"
    "#else\n"
    "inline float asinh_f32(float x) {return asinhf(x);}\n"
    "inline float acosh_f32(float x) {return acoshf(x);}\n"
    "inline float atanh_f32(float x) {return atanhf(x);}\n"
    "inline double asinh_f64(double x) {return asinh(x);}\n"
    "inline double acosh_f64(double x) {return acosh(x);}\n"
    "inline double atanh_f64(double x) {return atanh(x);}\n"
    "#endif\n"
    // Type-suffixed wrappers around the <math.h> functions.
    "inline float sqrt_f32(float x) {return sqrtf(x);}\n"
    "inline float sin_f32(float x) {return sinf(x);}\n"
    "inline float asin_f32(float x) {return asinf(x);}\n"
    "inline float cos_f32(float x) {return cosf(x);}\n"
    "inline float acos_f32(float x) {return acosf(x);}\n"
    "inline float tan_f32(float x) {return tanf(x);}\n"
    "inline float atan_f32(float x) {return atanf(x);}\n"
    "inline float sinh_f32(float x) {return sinhf(x);}\n"
    "inline float cosh_f32(float x) {return coshf(x);}\n"
    "inline float tanh_f32(float x) {return tanhf(x);}\n"
    "inline float hypot_f32(float x, float y) {return hypotf(x, y);}\n"
    "inline float exp_f32(float x) {return expf(x);}\n"
    "inline float log_f32(float x) {return logf(x);}\n"
    "inline float pow_f32(float x, float y) {return powf(x, y);}\n"
    "inline float floor_f32(float x) {return floorf(x);}\n"
    "inline float ceil_f32(float x) {return ceilf(x);}\n"
    "inline float round_f32(float x) {return roundf(x);}\n"
    "\n"
    "inline double sqrt_f64(double x) {return sqrt(x);}\n"
    "inline double sin_f64(double x) {return sin(x);}\n"
    "inline double asin_f64(double x) {return asin(x);}\n"
    "inline double cos_f64(double x) {return cos(x);}\n"
    "inline double acos_f64(double x) {return acos(x);}\n"
    "inline double tan_f64(double x) {return tan(x);}\n"
    "inline double atan_f64(double x) {return atan(x);}\n"
    "inline double sinh_f64(double x) {return sinh(x);}\n"
    "inline double cosh_f64(double x) {return cosh(x);}\n"
    "inline double tanh_f64(double x) {return tanh(x);}\n"
    "inline double hypot_f64(double x, double y) {return hypot(x, y);}\n"
    "inline double exp_f64(double x) {return exp(x);}\n"
    "inline double log_f64(double x) {return log(x);}\n"
    "inline double pow_f64(double x, double y) {return pow(x, y);}\n"
    "inline double floor_f64(double x) {return floor(x);}\n"
    "inline double ceil_f64(double x) {return ceil(x);}\n"
    "inline double round_f64(double x) {return round(x);}\n"
    "\n"
    // Special float values (referenced by CodeGen_C::visit(const FloatImm *)).
    "inline float nan_f32() {return NAN;}\n"
    "inline float neg_inf_f32() {return -INFINITY;}\n"
    "inline float inf_f32() {return INFINITY;}\n"
    "inline bool is_nan_f32(float x) {return x != x;}\n"
    "inline bool is_nan_f64(double x) {return x != x;}\n"
    // Rebuilds a float from a 32-bit pattern; FloatImm constants are
    // emitted through this to avoid decimal round-trip loss.
    "inline float float_from_bits(uint32_t bits) {\n"
    " union {\n"
    "  uint32_t as_uint;\n"
    "  float as_float;\n"
    " } u;\n"
    " u.as_uint = bits;\n"
    " return u.as_float;\n"
    "}\n"
    // Combines two 32-bit halves; lo is zero-extended before the OR.
    "inline int64_t make_int64(int32_t hi, int32_t lo) {\n"
    "    return (((int64_t)hi) << 32) | (uint32_t)lo;\n"
    "}\n"
    // Assembles a double from two 32-bit words stored in memory order.
    "inline double make_float64(int32_t i0, int32_t i1) {\n"
    "    union {\n"
    "        int32_t as_int32[2];\n"
    "        double as_double;\n"
    "    } u;\n"
    "    u.as_int32[0] = i0;\n"
    "    u.as_int32[1] = i1;\n"
    "    return u.as_double;\n"
    "}\n"
    "\n"
    // Generic max/min; the Max and Min IR nodes are lowered to extern calls
    // named "max"/"min".
    "template<typename T> T max(T a, T b) {if (a > b) return a; return b;}\n"
    "template<typename T> T min(T a, T b) {if (a < b) return a; return b;}\n"

    // This may look wasteful, but it's the right way to do
    // it. Compilers understand memcpy and will convert it to a no-op
    // when used in this way. See http://blog.regehr.org/archives/959
    // for a detailed comparison of type-punning methods.
    "template<typename A, typename B> A reinterpret(B b) {A a; memcpy(&a, &b, sizeof(a)); return a;}\n"
    "\n"
    // Overwrites min/extent/stride for all four dimensions of *b. The
    // elem_size argument is accepted but not written. Always returns true.
    "static bool halide_rewrite_buffer(buffer_t *b, int32_t elem_size,\n"
    "                           int32_t min0, int32_t extent0, int32_t stride0,\n"
    "                           int32_t min1, int32_t extent1, int32_t stride1,\n"
    "                           int32_t min2, int32_t extent2, int32_t stride2,\n"
    "                           int32_t min3, int32_t extent3, int32_t stride3) {\n"
    " b->min[0] = min0;\n"
    " b->min[1] = min1;\n"
    " b->min[2] = min2;\n"
    " b->min[3] = min3;\n"
    " b->extent[0] = extent0;\n"
    " b->extent[1] = extent1;\n"
    " b->extent[2] = extent2;\n"
    " b->extent[3] = extent3;\n"
    " b->stride[0] = stride0;\n"
    " b->stride[1] = stride1;\n"
    " b->stride[2] = stride2;\n"
    " b->stride[3] = stride3;\n"
    " return true;\n"
    "}\n";
}

CodeGen_C::CodeGen_C(ostream &s, OutputKind output_kind, const std::string &guard) : IRPrinter(s), id("$$ BAD ID $$"), output_kind(output_kind) {
    // Writes the preamble of the generated file. Headers get an include guard
    // and no implementation helpers; source files get the standard includes
    // plus the runtime declarations/helpers held in `globals`.
    const bool header = is_header();

    if (header) {
        const string guard_macro = "HALIDE_" + print_name(guard);
        stream << "#ifndef " << guard_macro << '\n'
               << "#define " << guard_macro << '\n';
    } else {
        stream << headers;
    }

    // Both kinds of output need buffer_t and a forward declaration of
    // halide_filter_metadata_t (include HalideRuntime.h for the full definition).
    stream << buffer_t_definition
           << "struct halide_filter_metadata_t;\n";

    if (!header) {
        stream << globals;
    }

    // Default (empty) HALIDE_FUNCTION_ATTRS; hosts may predefine it,
    // e.g. as __attribute__((warn_unused_result)).
    stream << "#ifndef HALIDE_FUNCTION_ATTRS\n"
              "#define HALIDE_FUNCTION_ATTRS\n"
              "#endif\n";

    if (!is_c_plus_plus_interface()) {
        // Everything from here on is extern "C" when compiled as C++;
        // the destructor closes this block.
        stream << "#ifdef __cplusplus\n"
                  "extern \"C\" {\n"
                  "#endif\n";
    }
}

CodeGen_C::~CodeGen_C() {
    // Close what the constructor opened, innermost first: the extern "C"
    // block (C interface only), then the include guard (headers only).
    const bool close_extern_c = !is_c_plus_plus_interface();
    if (close_extern_c) {
        stream << "#ifdef __cplusplus\n"
                  "}  // extern \"C\"\n"
                  "#endif\n";
    }

    const bool close_guard = is_header();
    if (close_guard) {
        stream << "#endif\n";
    }
}

namespace {
// Returns the C/C++ spelling of a scalar Halide type.
//
// type          - must be scalar (lanes() == 1); vector types are a user error.
// include_space - append a trailing space to the spelling; never appended for
//                 handle types.
// c_plus_plus   - when false (plain C output), handle types that have no type
//                 info, are classes, or live inside a namespace/enclosing type
//                 are spelled as "const void *".
string type_to_c_type(Type type, bool include_space, bool c_plus_plus = true) {
    bool needs_space = true;
    ostringstream oss;
    user_assert(type.lanes() == 1) << "Can't use vector types when compiling to C (yet)\n";
    if (type.is_float()) {
        if (type.bits() == 32) {
            oss << "float";
        } else if (type.bits() == 64) {
            oss << "double";
        } else {
            user_error << "Can't represent a float with this many bits in C: " << type << "\n";
        }

    } else if (type.is_handle()) {
        needs_space = false;

        // If there is no type info or is generating C (not C++) and
        // the type is a class or in an inner scope, just use void *.
        if (type.handle_type == nullptr ||
            (!c_plus_plus &&
             (!type.handle_type->namespaces.empty() ||
              !type.handle_type->enclosing_types.empty() ||
              type.handle_type->inner_name.cpp_type_type == halide_cplusplus_type_name::Class))) {
            oss << "const void *";
        } else {
            if (type.handle_type->inner_name.cpp_type_type == halide_cplusplus_type_name::Struct) {
                oss << "struct ";
            } else if (type.handle_type->inner_name.cpp_type_type == halide_cplusplus_type_name::Class) {
                oss << "class ";
            }
            // Fully qualify from the global scope: ::ns1::ns2::Outer::Inner
            if (!type.handle_type->namespaces.empty() ||
                !type.handle_type->enclosing_types.empty()) {
                oss << "::";
                for (size_t i = 0; i < type.handle_type->namespaces.size(); i++) {
                    oss << type.handle_type->namespaces[i] << "::";
                }
                for (size_t i = 0; i < type.handle_type->enclosing_types.size(); i++) {
                    oss << type.handle_type->enclosing_types[i].name << "::";
                }
            }
            oss << type.handle_type->inner_name.name;
            if (type.handle_type->reference_type == halide_handle_cplusplus_type::LValueReference) {
                oss << " &";
            } else if (type.handle_type->reference_type == halide_handle_cplusplus_type::RValueReference) {
                // Rvalue references must test RValueReference; testing
                // LValueReference again here made this branch unreachable.
                oss << " &&";
            }
            // Apply qualifiers and pointer levels in order; the first
            // modifier without the Pointer bit ends the chain.
            for (auto modifier : type.handle_type->cpp_type_modifiers) {
                if (modifier & halide_handle_cplusplus_type::Const) {
                    oss << " const";
                }
                if (modifier & halide_handle_cplusplus_type::Volatile) {
                    oss << " volatile";
                }
                if (modifier & halide_handle_cplusplus_type::Restrict) {
                    oss << " restrict";
                }
                if (modifier & halide_handle_cplusplus_type::Pointer) {
                    oss << " *";
                } else {
                    break;
                }
            }
        }
    } else {
        switch (type.bits()) {
        case 1:
            oss << "bool";
            break;
        case 8: case 16: case 32: case 64:
            if (type.is_uint()) oss << 'u';
            oss << "int" << type.bits() << "_t";
            break;
        default:
            user_error << "Can't represent an integer with this many bits in C: " << type << "\n";
        }
    }
    if (include_space && needs_space)
        oss << " ";
    return oss.str();
}
}

string CodeGen_C::print_type(Type type, AppendSpaceIfNeeded space_option) {
    // Thin adapter over type_to_c_type: translate the enum into its boolean
    // "append a trailing space" flag (C++ spelling is the default).
    const bool want_trailing_space = (space_option == AppendSpace);
    return type_to_c_type(type, want_trailing_space);
}

string CodeGen_C::print_reinterpret(Type type, Expr e) {
    // Bit-cast e to `type` via the reinterpret<A>(B) template that the
    // generated source's preamble defines.
    const string value = print_expr(e);
    const string target = print_type(type);
    return "reinterpret<" + target + ">(" + value + ")";
}

// Mangles an IR name into a C identifier:
//  - a name starting with a letter gets a '_' prefix,
//  - '.' becomes '_' and '$' becomes "__",
//  - any other character that is neither '_' nor alphanumeric becomes "___".
string CodeGen_C::print_name(const string &name) {
    ostringstream oss;

    // Prefix an underscore to avoid reserved words (e.g. a variable named "while").
    // The <cctype> classifiers require a value representable as unsigned char;
    // passing a negative char (e.g. a UTF-8 byte) is undefined behavior.
    if (!name.empty() && isalpha(static_cast<unsigned char>(name[0]))) {
        oss << '_';
    }

    for (char c : name) {
        if (c == '.') {
            oss << '_';
        } else if (c == '$') {
            oss << "__";
        } else if (c != '_' && !isalnum(static_cast<unsigned char>(c))) {
            oss << "___";
        } else {
            oss << c;
        }
    }
    return oss.str();
}

namespace {
class ExternCallPrototypes : public IRGraphVisitor {
    ostream &stream;
    std::set<string> &emitted;
    using IRGraphVisitor::visit;
    // TODO: This class should likely be able to signal an error if C++
    // code shows up and started_in_c_plus_plus isn't true, but the logic
    // is orthogonal.
    const bool started_in_c_plus_plus;
    bool in_c_plus_plus;

    void switch_calling_convention(bool c_plus_plus) {
      if (in_c_plus_plus != c_plus_plus) {
          if (in_c_plus_plus) {
              stream << "}\n";
          } else {
              stream << "extern \"C\" {\n";
          }
          in_c_plus_plus = c_plus_plus;
      }
    }

    void visit(const Call *op) {
        IRGraphVisitor::visit(op);

        if (op->call_type == Call::Extern ||
            op->call_type == Call::ExternCPlusPlus) {
            switch_calling_convention(op->call_type == Call::ExternCPlusPlus);
            // TODO: optimize generation of namespacing to reuse namespace decls.
            int32_t namespace_count = 0;
            std::string name;
            if (op->call_type == Call::ExternCPlusPlus) {
                std::vector<std::string> namespaces;
                name = extract_namespaces(op->name, namespaces);
                for (auto const &ns : namespaces) {
                    stream << "namespace " << ns << " { ";
                }
                namespace_count = namespaces.size();
            } else {
                name = op->name;
            }

            if (!emitted.count(name)) {
                stream << type_to_c_type(op->type, true) << " " << name << "(";
                if (function_takes_user_context(name)) {
                    stream << "void *";
                    if (op->args.size()) {
                        stream << ", ";
                    }
                }
                for (size_t i = 0; i < op->args.size(); i++) {
                    if (i > 0) {
                        stream << ", ";
                    }
                    if (op->args[i].as<StringImm>()) {
                        stream << "const char *";
                    } else {
                      stream << type_to_c_type(op->args[i].type(), true);
                    }
                }
                stream << ");";
                for (int32_t i = 0; i < namespace_count; i++) {
                    stream << " }";
                }
                stream << "\n";
                emitted.insert(op->name); // Keep namespacing here.
            }
        }
    }

public:
  ExternCallPrototypes(ostream &s, std::set<string> &emitted, bool in_c_plus_plus)
      : stream(s), emitted(emitted), started_in_c_plus_plus(in_c_plus_plus), in_c_plus_plus(in_c_plus_plus) {
        size_t j = 0;
        // Make sure we don't catch calls that are already in the global declarations
        for (size_t i = 0; i < globals.size(); i++) {
            char c = globals[i];
            if (c == '(' && i > j+1) {
                // Could be the end of a function_name.
                emitted.insert(globals.substr(j+1, i-j-1));
            }

            if (('A' <= c && c <= 'Z') ||
                ('a' <= c && c <= 'z') ||
                c == '_' ||
                ('0' <= c && c <= '9')) {
                // Could be part of a function name.
            } else {
                j = i;
            }

        }
    }

  ~ExternCallPrototypes() {
      switch_calling_convention(started_in_c_plus_plus);
  }
};
}

void CodeGen_C::compile(const Module &input) {
    // Emit every embedded buffer first, then every lowered function, each in
    // the order the module stores them.
    for (const auto &buf : input.buffers) {
        compile(buf);
    }
    for (const auto &fn : input.functions) {
        compile(fn);
    }
}

void CodeGen_C::compile(const LoweredFunc &f) {
    // Don't put non-external function declarations in headers.
    if (is_header() && f.linkage != LoweredFunc::External) {
        return;
    }

    internal_assert(emitted.count(f.name) == 0)
        << "Function '" << f.name << "'  has already been emitted.\n";
    emitted.insert(f.name);

    const std::vector<Argument> &args = f.args;

    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].type.handle_type != NULL) {
            if (!args[i].type.handle_type->namespaces.empty()) {
                if (args[i].type.handle_type->inner_name.cpp_type_type != halide_cplusplus_type_name::Simple) {
                    for (size_t ns = 0; ns < args[i].type.handle_type->namespaces.size(); ns++ ) {
                        for (size_t indent = 0; indent < ns; indent++) {
                           stream << "    ";
                        }
                        stream << indent << "namespace " << args[i].type.handle_type->namespaces[ns] << " {\n";
                    }
                    for (size_t indent = 0; indent < args[i].type.handle_type->namespaces.size(); indent++) {
                        stream << "    ";
                    }
                    if (args[i].type.handle_type->inner_name.cpp_type_type != halide_cplusplus_type_name::Struct) {
                        stream << "struct " << args[i].type.handle_type->inner_name.name << ";\n";
                    } else {
                        stream << "class " << args[i].type.handle_type->inner_name.name << ";\n";
                    }
                    for (size_t ns = 0; ns < args[i].type.handle_type->namespaces.size(); ns++ ) {
                        for (size_t indent = 0; indent < ns; indent++) {
                           stream << "    ";
                        }
                        stream << indent << "}\n";
                    }
                }
            }
        }           
    }

    have_user_context = false;
    for (size_t i = 0; i < args.size(); i++) {
        // TODO: check that its type is void *?
        have_user_context |= (args[i].name == "__user_context");
    }

    // Emit prototypes for any extern calls used.
    if (!is_header()) {
        stream << "\n";
        ExternCallPrototypes e(stream, emitted, is_c_plus_plus_interface());
        f.body.accept(&e);
        stream << "\n";
    }

    std::vector<std::string> namespaces;
    std::string simple_name = extract_namespaces(f.name, namespaces);
    if (!is_c_plus_plus_interface()) {
        user_assert(namespaces.empty()) <<
            "Namespace qualifiers not allowed on function name if not compiling with Target::CPlusPlusNameMangling.\n";
    }

    if (!namespaces.empty()) {
        const char *separator = "";
        for (const auto &ns : namespaces) {
            stream << separator << "namespace " << ns << " {";
            separator = " ";
        }
        stream << "\n\n";
    }

    // Emit the function prototype
    if (f.linkage != LoweredFunc::External) {
        // If the function isn't public, mark it static.
        stream << "static ";
    }
    stream << "int " << simple_name << "(";
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer()) {
            stream << "buffer_t *"
                   << print_name(args[i].name)
                   << "_buffer";
        } else {
            stream << print_type(args[i].type, AppendSpace)
                   << print_name(args[i].name);
        }

        if (i < args.size()-1) stream << ", ";
    }

    if (is_header()) {
        stream << ") HALIDE_FUNCTION_ATTRS;\n";
    } else {
        stream << ") HALIDE_FUNCTION_ATTRS {\n";
        indent += 1;

        // Unpack the buffer_t's
        for (size_t i = 0; i < args.size(); i++) {
            if (args[i].is_buffer()) {
                push_buffer(args[i].type, args[i].name);
            }
        }
        // Emit the body
        print(f.body);

        // Return success.
        do_indent();
        stream << "return 0;\n";

        indent -= 1;
        stream << "}\n";

        // Done with the buffer_t's, pop the associated symbols.
        for (size_t i = 0; i < args.size(); i++) {
            if (args[i].is_buffer()) {
                pop_buffer(args[i].name);
            }
        }
    }

    if (is_header()) {
        // If this is a header and we are here, we know this is an externally visible Func, so
        // declare the argv function.
        stream << "int " << simple_name << "_argv(void **args) HALIDE_FUNCTION_ATTRS;\n";
    }

    // Close namespaces here as metadata must be outside them
    if (!namespaces.empty()) {
        stream << "\n";
        for (size_t i = 0; i < namespaces.size(); i++) {
            stream << "}";
        }
        stream << " // Close namespaces ";
        const char *separator = "";
        for (const auto &ns : namespaces) {
            stream << separator << ns;
            separator = "::";
        }

        stream << "\n\n";
    }

    if (is_header()) {
        // And also the metadata.
       stream << "extern const struct halide_filter_metadata_t " << simple_name << "_metadata;\n";
    }
}

// Embeds a constant buffer into the generated source: a static byte array
// holding its host data, a static buffer_t describing it, and a global
// buffer_t pointer named after the buffer. Nothing is emitted for headers.
void CodeGen_C::compile(const Buffer &buffer) {
    // Don't define buffers in headers.
    if (is_header()) {
        return;
    }

    string name = print_name(buffer.name());
    buffer_t b = *(buffer.raw_buffer());

    // Validate before reading host memory below.
    user_assert(b.host) << "Can't embed image: " << buffer.name() << " because it has a null host pointer\n";
    user_assert(!b.dev_dirty) << "Can't embed image: " << buffer.name() << " because it has a dirty device pointer\n";

    // Figure out the offset of the last pixel. A zero extent marks the end of
    // the used dimensions; also stop at the array bound so a buffer using all
    // dimensions does not read past extent[].
    const size_t max_dims = sizeof(b.extent) / sizeof(b.extent[0]);
    size_t num_elems = 1;
    for (size_t d = 0; d < max_dims && b.extent[d]; d++) {
        num_elems += b.stride[d] * (b.extent[d] - 1);
    }

    // Emit the data
    stream << "static uint8_t " << name << "_data[] __attribute__ ((aligned (32))) = {";
    for (size_t i = 0; i < num_elems * b.elem_size; i++) {
        if (i > 0) stream << ", ";
        stream << (int)(b.host[i]);
    }
    stream << "};\n";

    // Emit the buffer_t
    stream << "static buffer_t " << name << "_buffer = {"
           << "0, " // dev
           << "&" << name << "_data[0], " // host
           << "{" << b.extent[0] << ", " << b.extent[1] << ", " << b.extent[2] << ", " << b.extent[3] << "}, "
           << "{" << b.stride[0] << ", " << b.stride[1] << ", " << b.stride[2] << ", " << b.stride[3] << "}, "
           << "{" << b.min[0] << ", " << b.min[1] << ", " << b.min[2] << ", " << b.min[3] << "}, "
           << b.elem_size << ", "
           << "0, " // host_dirty
           << "0};\n"; //dev_dirty

    // Make a global pointer to it
    stream << "static buffer_t *" << name << " = &" << name << "_buffer;\n";
}

void CodeGen_C::push_buffer(Type t, const std::string &buffer_name) {
    // Unpacks the buffer_t argument "<name>_buffer" into typed locals:
    // a host pointer, a host/dev null flag, per-dimension min/extent/stride
    // (4 dims each) and elem_size. Each local is followed by a (void) cast
    // to silence unused-variable warnings.
    const string name = print_name(buffer_name);
    const string buf_name = name + "_buffer";
    const string type = print_type(t);

    do_indent();
    stream << type << " *" << name << " = (" << type << " *)(" << buf_name << "->host);\n";

    // Record this buffer's element type under its IR name.
    Allocation alloc;
    alloc.type = t;
    allocations.push(buffer_name, alloc);

    do_indent();
    stream << "(void)" << name << ";\n";

    do_indent();
    stream << "const bool " << name << "_host_and_dev_are_null = ("
           << buf_name << "->host == nullptr) && ("
           << buf_name << "->dev == 0);\n";
    do_indent();
    stream << "(void)" << name << "_host_and_dev_are_null;\n";

    // All mins first, then all extents, then all strides.
    static const char *const per_dim_fields[] = {"min", "extent", "stride"};
    for (const char *field : per_dim_fields) {
        for (int dim = 0; dim < 4; dim++) {
            const string local = name + "_" + field + "_" + std::to_string(dim);
            do_indent();
            stream << "const int32_t " << local << " = "
                   << buf_name << "->" << field << "[" << dim << "];\n";
            do_indent();
            stream << "(void)" << local << ";\n";
        }
    }

    do_indent();
    stream << "const int32_t " << name << "_elem_size = " << buf_name << "->elem_size;\n";
    do_indent();
    stream << "(void)" << name << "_elem_size;\n";
}

// Removes the allocations entry that push_buffer pushed for this buffer.
void CodeGen_C::pop_buffer(const std::string &buffer_name) {
    allocations.pop(buffer_name);
}

// Generates code for e and returns the C expression (a literal, a name, or a
// temporary) that holds its value. id is reset to a sentinel first, so a
// visitor that fails to set id leaves "$$ BAD ID $$" visible in the output.
string CodeGen_C::print_expr(Expr e) {
    id = "$$ BAD ID $$";
    e.accept(this);
    return id;
}

// Generates code for a statement by dispatching to this visitor.
void CodeGen_C::print_stmt(Stmt s) {
    s.accept(this);
}

string CodeGen_C::print_assignment(Type t, const std::string &rhs) {
    // Materialize rhs into a fresh temporary, unless an identical rhs was
    // already assigned in the current scope, in which case reuse that name.
    // Sets and returns id.
    auto hit = cache.find(rhs);
    if (hit != cache.end()) {
        id = hit->second;
        return id;
    }

    id = unique_name('_');
    do_indent();
    stream << print_type(t, AppendSpace) << id << " = " << rhs << ";\n";
    cache[rhs] = id;
    return id;
}

void CodeGen_C::open_scope() {
    // Drop cached temporaries, write "{" at the current indentation, then
    // indent everything that follows one level deeper.
    cache.clear();
    do_indent();
    stream << "{\n";
    indent++;
}

void CodeGen_C::close_scope(const std::string &comment) {
    // Drop cached temporaries, dedent, and write "}" with an optional
    // trailing "// comment".
    cache.clear();
    indent--;
    do_indent();
    stream << "}";
    if (!comment.empty()) {
        stream << " // " << comment;
    }
    stream << "\n";
}

// A variable reference is just its mangled name; no temporary is emitted.
void CodeGen_C::visit(const Variable *op) {
    id = print_name(op->name);
}

void CodeGen_C::visit(const Cast *op) {
    print_assignment(op->type, "(" + print_type(op->type) + ")(" + print_expr(op->value) + ")");
}

void CodeGen_C::visit_binop(Type t, Expr a, Expr b, const char * op) {
    // Evaluate the left operand, then the right, and assign "lhs op rhs"
    // to a temporary of type t.
    const string left = print_expr(a);
    const string right = print_expr(b);
    string expr = left;
    expr += ' ';
    expr += op;
    expr += ' ';
    expr += right;
    print_assignment(t, expr);
}

// Add, Sub and Mul map directly onto the corresponding C operators.
void CodeGen_C::visit(const Add *op) {
    visit_binop(op->type, op->a, op->b, "+");
}

void CodeGen_C::visit(const Sub *op) {
    visit_binop(op->type, op->a, op->b, "-");
}

void CodeGen_C::visit(const Mul *op) {
    visit_binop(op->type, op->a, op->b, "*");
}

// Emits division.
// - Divisor a constant power of two: an arithmetic right shift.
// - Signed integers: start from C's truncating quotient q and adjust it so
//   the remainder a - q*b is never negative: q-1 when the truncated remainder
//   is negative and b > 0, q+1 when it is negative and b < 0.
// - Otherwise (unsigned, float): plain C division.
void CodeGen_C::visit(const Div *op) {
    int bits;
    if (is_const_power_of_two_integer(op->b, &bits)) {
        ostringstream oss;
        oss << print_expr(op->a) << " >> " << bits;
        print_assignment(op->type, oss.str());
    } else if (op->type.is_int()) {
        string a = print_expr(op->a);
        string b = print_expr(op->b);
        // q = a / b
        string q = print_assignment(op->type, a + " / " + b);
        // r = a - q * b
        string r = print_assignment(op->type, a + " - " + q + " * " + b);
        // bs = b >> (8*sizeof(T) - 1)   (-1 if b is negative, else 0)
        string bs = print_assignment(op->type, b + " >> (" + print_type(op->type.element_of()) + ")" + std::to_string(op->type.bits() - 1));
        // rs = r >> (8*sizeof(T) - 1)   (-1 if r is negative, else 0)
        string rs = print_assignment(op->type, r + " >> (" + print_type(op->type.element_of()) + ")" + std::to_string(op->type.bits() - 1));
        // id = q - (rs & bs) + (rs & ~bs)
        print_assignment(op->type, q + " - (" + rs + " & " + bs + ") + (" + rs + " & ~" + bs + ")");
    } else {
        visit_binop(op->type, op->a, op->b, "/");
    }
}

// Emits modulus.
// - Divisor a constant power of two 2^k: a bitwise AND with 2^k - 1.
// - Signed integers: C's remainder r = a % b, plus |b| when r is negative,
//   so the result is never negative.
// - Otherwise (unsigned, float): plain C '%'.
void CodeGen_C::visit(const Mod *op) {
    int bits;
    if (is_const_power_of_two_integer(op->b, &bits)) {
        // Build the mask in unsigned 64-bit arithmetic: (1 << bits) in int
        // overflows (undefined behavior) for bits >= 31, which 64-bit types reach.
        const unsigned long long mask = (1ULL << bits) - 1;
        ostringstream oss;
        oss << print_expr(op->a) << " & " << mask;
        print_assignment(op->type, oss.str());
    } else if (op->type.is_int()) {
        string a = print_expr(op->a);
        string b = print_expr(op->b);
        // r = a % b
        string r = print_assignment(op->type, a + " % " + b);
        // rs = r >> (8*sizeof(T) - 1)   (-1 if r is negative, else 0)
        string rs = print_assignment(op->type, r + " >> (" + print_type(op->type.element_of()) + ")" + std::to_string(op->type.bits() - 1));
        // abs_b = abs(b)
        string abs_b = print_expr(cast(op->type, abs(op->b)));
        // id = r + (abs_b & rs)
        print_assignment(op->type, r + " + (" + abs_b + " & " + rs + ")");
    } else {
        visit_binop(op->type, op->a, op->b, "%");
    }
}

// Max and Min are lowered to Call::Extern calls named "max"/"min"; the
// generated source's preamble defines templates with those names.
void CodeGen_C::visit(const Max *op) {
    print_expr(Call::make(op->type, "max", {op->a, op->b}, Call::Extern));
}

void CodeGen_C::visit(const Min *op) {
    print_expr(Call::make(op->type, "min", {op->a, op->b}, Call::Extern));
}

// Comparison and logical binary operators map directly onto the C operators.
void CodeGen_C::visit(const EQ *op) {
    visit_binop(op->type, op->a, op->b, "==");
}

void CodeGen_C::visit(const NE *op) {
    visit_binop(op->type, op->a, op->b, "!=");
}

void CodeGen_C::visit(const LT *op) {
    visit_binop(op->type, op->a, op->b, "<");
}

void CodeGen_C::visit(const LE *op) {
    visit_binop(op->type, op->a, op->b, "<=");
}

void CodeGen_C::visit(const GT *op) {
    visit_binop(op->type, op->a, op->b, ">");
}

void CodeGen_C::visit(const GE *op) {
    visit_binop(op->type, op->a, op->b, ">=");
}

void CodeGen_C::visit(const Or *op) {
    visit_binop(op->type, op->a, op->b, "||");
}

void CodeGen_C::visit(const And *op) {
    visit_binop(op->type, op->a, op->b, "&&");
}

void CodeGen_C::visit(const Not *op) {
    print_assignment(op->type, "!(" + print_expr(op->a) + ")");
}

void CodeGen_C::visit(const IntImm *op) {
    if (op->type == Int(32)) {
        id = std::to_string(op->value);
    } else {
        print_assignment(op->type, "(" + print_type(op->type) + ")(" + std::to_string(op->value) + ")");
    }
}

void CodeGen_C::visit(const UIntImm *op) {
    print_assignment(op->type, "(" + print_type(op->type) + ")(" + std::to_string(op->value) + ")");
}

void CodeGen_C::visit(const StringImm *op) {
    ostringstream oss;
    oss << Expr(op);
    id = oss.str();
}

// Portable NaN test: NaN is the only float/double that does not compare
// equal to itself (dsharlet noted there is no portable isnan function).
template <typename T>
static bool isnan(T x) {
    return !(x == x);
}

// True when x is +infinity or -infinity; always false for types without an
// infinity value (e.g. integers).
template <typename T>
static bool isinf(T x)
{
    if (!std::numeric_limits<T>::has_infinity) {
        return false;
    }
    const T inf = std::numeric_limits<T>::infinity();
    return x == inf || x == -inf;
}

void CodeGen_C::visit(const FloatImm *op) {
    if (isnan(op->value)) {
        id = "nan_f32()";
    } else if (isinf(op->value)) {
        if (op->value > 0) {
            id = "inf_f32()";
        } else {
            id = "neg_inf_f32()";
        }
    } else {
        // Write the constant as reinterpreted uint to avoid any bits lost in conversion.
        union {
            uint32_t as_uint;
            float as_float;
        } u;
        u.as_float = op->value;

        ostringstream oss;
        oss << "float_from_bits(" << u.as_uint << " /* " << u.as_float << " */)";
        id = oss.str();
    }
}

// Emit C code for a Call node.
//
// Intrinsics are lowered inline. Some of them (if_then_else, copy_buffer_t,
// create_buffer_t, make_struct, stringify, register_destructor,
// set_*_dirty) first write helper statements directly into `stream`. Any
// other Extern / ExternCPlusPlus / PureExtern call is emitted as a plain
// function call, with the user context prepended when
// function_takes_user_context() says so. The value expression built up in
// `rhs` is finally handed to print_assignment with op->type.
void CodeGen_C::visit(const Call *op) {

    internal_assert(op->call_type == Call::Extern ||
                    op->call_type == Call::ExternCPlusPlus ||
                    op->call_type == Call::PureExtern ||
                    op->call_type == Call::Intrinsic ||
                    op->call_type == Call::PureIntrinsic)
        << "Can only codegen extern calls and intrinsics\n";

    ostringstream rhs;

    // Handle intrinsics first
    if (op->is_intrinsic(Call::debug_to_file)) {
        internal_assert(op->args.size() == 3);
        const StringImm *string_imm = op->args[0].as<StringImm>();
        internal_assert(string_imm);
        string filename = string_imm->value;
        string typecode = print_expr(op->args[1]);
        string buffer = print_name(print_expr(op->args[2]));

        rhs << "halide_debug_to_file(";
        rhs << (have_user_context ? "__user_context_" : "nullptr");
        rhs << ", \"" + filename + "\", " + typecode;
        rhs << ", (struct buffer_t *)" << buffer;
        rhs << ")";
    } else if (op->is_intrinsic(Call::bitwise_and)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << a0 << " & " << a1;
    } else if (op->is_intrinsic(Call::bitwise_xor)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << a0 << " ^ " << a1;
    } else if (op->is_intrinsic(Call::bitwise_or)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << a0 << " | " << a1;
    } else if (op->is_intrinsic(Call::bitwise_not)) {
        internal_assert(op->args.size() == 1);
        rhs << "~" << print_expr(op->args[0]);
    } else if (op->is_intrinsic(Call::reinterpret)) {
        internal_assert(op->args.size() == 1);
        rhs << print_reinterpret(op->type, op->args[0]);
    } else if (op->is_intrinsic(Call::shift_left)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << a0 << " << " << a1;
    } else if (op->is_intrinsic(Call::shift_right)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << a0 << " >> " << a1;
    } else if (op->is_intrinsic(Call::rewrite_buffer)) {
        int dims = ((int)(op->args.size())-2)/3;
        (void)dims; // In case internal_assert is ifdef'd to do nothing
        internal_assert((int)(op->args.size()) == dims*3 + 2);
        internal_assert(dims <= 4);
        vector<string> args(op->args.size());
        const Variable *v = op->args[0].as<Variable>();
        internal_assert(v);
        args[0] = print_name(v->name);
        for (size_t i = 1; i < op->args.size(); i++) {
            args[i] = print_expr(op->args[i]);
        }
        // halide_rewrite_buffer always takes 14 arguments (2 + 3 * 4 dims);
        // unused trailing dimensions are padded with zeros.
        rhs << "halide_rewrite_buffer(";
        for (size_t i = 0; i < 14; i++) {
            if (i > 0) rhs << ", ";
            if (i < args.size()) {
                rhs << args[i];
            } else {
                rhs << '0';
            }
        }
        rhs << ")";
    } else if (op->is_intrinsic(Call::lerp)) {
        internal_assert(op->args.size() == 3);
        Expr e = lower_lerp(op->args[0], op->args[1], op->args[2]);
        rhs << print_expr(e);
    } else if (op->is_intrinsic(Call::absd)) {
        internal_assert(op->args.size() == 2);
        Expr a = op->args[0];
        Expr b = op->args[1];
        Expr e = select(a < b, b - a, a - b);
        rhs << print_expr(e);
    } else if (op->is_intrinsic(Call::null_handle)) {
        rhs << "nullptr";
    } else if (op->is_intrinsic(Call::address_of)) {
        // Check the argument count before touching args[0].
        internal_assert(op->args.size() == 1);
        const Load *l = op->args[0].as<Load>();
        internal_assert(l);
        rhs << "(("
            << print_type(l->type.element_of()) // index is in elements, not vectors.
            << " *)"
            << print_name(l->name)
            << " + "
            << print_expr(l->index)
            << ")";
    } else if (op->is_intrinsic(Call::return_second)) {
        internal_assert(op->args.size() == 2);
        string arg0 = print_expr(op->args[0]);
        string arg1 = print_expr(op->args[1]);
        // C comma operator: evaluate arg0 for its side effects, yield arg1.
        rhs << "(" << arg0 << ", " << arg1 << ")";
    } else if (op->is_intrinsic(Call::if_then_else)) {
        internal_assert(op->args.size() == 3);

        // Declare the result first, then assign it from whichever branch
        // runs, so that only the taken branch's expression is evaluated.
        string result_id = unique_name('_');

        do_indent();
        stream << print_type(op->args[1].type(), AppendSpace)
               << result_id << ";\n";

        string cond_id = print_expr(op->args[0]);

        do_indent();
        stream << "if (" << cond_id << ")\n";
        open_scope();
        string true_case = print_expr(op->args[1]);
        do_indent();
        stream << result_id << " = " << true_case << ";\n";
        close_scope("if " + cond_id);
        do_indent();
        stream << "else\n";
        open_scope();
        string false_case = print_expr(op->args[2]);
        do_indent();
        stream << result_id << " = " << false_case << ";\n";
        close_scope("if " + cond_id + " else");

        rhs << result_id;
    } else if (op->is_intrinsic(Call::copy_buffer_t)) {
        internal_assert(op->args.size() == 1);
        string arg = print_expr(op->args[0]);
        string buf_id = unique_name('B');
        // The declaration must be a complete, properly indented C statement
        // (it previously lacked both the indent and the trailing ';').
        do_indent();
        stream << "buffer_t " << buf_id << " = *((buffer_t *)(" << arg << "));\n";
        rhs << "(&" << buf_id << ")";
    } else if (op->is_intrinsic(Call::create_buffer_t)) {
        internal_assert(op->args.size() >= 2);
        // Args are: host pointer, a value whose type supplies elem_size,
        // then one (min, extent, stride) triple per dimension. buffer_t only
        // has room for 4 dimensions, so reject anything else up front
        // (mirrors the check in the rewrite_buffer branch).
        internal_assert((op->args.size() - 2) % 3 == 0)
            << "create_buffer_t expects (min, extent, stride) triples\n";
        int dims = ((int)op->args.size() - 2)/3;
        internal_assert(dims <= 4)
            << "create_buffer_t supports at most 4 dimensions\n";
        vector<string> args;
        args.push_back(print_expr(op->args[0]));
        args.push_back(print_expr(op->args[1].type().bytes()));
        for (size_t i = 2; i < op->args.size(); i++) {
            args.push_back(print_expr(op->args[i]));
        }
        string buf_id = unique_name('B');
        do_indent();
        stream << "buffer_t " << buf_id << " = {0};\n";
        do_indent();
        stream << buf_id << ".host = const_cast<uint8_t *>((const uint8_t *)(" << args[0] << "));\n";
        do_indent();
        stream << buf_id << ".elem_size = " << args[1] << ";\n";
        for (int i = 0; i < dims; i++) {
            do_indent();
            stream << buf_id << ".min[" << i << "] = " << args[i*3+2] << ";\n";
            do_indent();
            stream << buf_id << ".extent[" << i << "] = " << args[i*3+3] << ";\n";
            do_indent();
            stream << buf_id << ".stride[" << i << "] = " << args[i*3+4] << ";\n";
        }
        rhs << "(&" + buf_id + ")";
    } else if (op->is_intrinsic(Call::extract_buffer_max)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << "(((buffer_t *)(" << a0 << "))->min[" << a1 << "] + " <<
            "((buffer_t *)(" << a0 << "))->extent[" << a1 << "] - 1)";
    } else if (op->is_intrinsic(Call::extract_buffer_min)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        rhs << "((buffer_t *)(" << a0 << "))->min[" << a1 << "]";
    } else if (op->is_intrinsic(Call::extract_buffer_host)) {
        internal_assert(op->args.size() == 1);
        string a0 = print_expr(op->args[0]);
        rhs << "((buffer_t *)(" << a0 << "))->host";
    } else if (op->is_intrinsic(Call::set_host_dirty)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        do_indent();
        stream << "((buffer_t *)(" << a0 << "))->host_dirty = " << a1 << ";\n";
        rhs << "0";
    } else if (op->is_intrinsic(Call::set_dev_dirty)) {
        internal_assert(op->args.size() == 2);
        string a0 = print_expr(op->args[0]);
        string a1 = print_expr(op->args[1]);
        do_indent();
        stream << "((buffer_t *)(" << a0 << "))->dev_dirty = " << a1 << ";\n";
        rhs << "0";
    } else if (op->is_intrinsic(Call::abs)) {
        internal_assert(op->args.size() == 1);
        Expr a0 = op->args[0];
        rhs << print_expr(cast(op->type, select(a0 > 0, a0, -a0)));
    } else if (op->is_intrinsic(Call::memoize_expr)) {
        internal_assert(op->args.size() >= 1);
        string arg = print_expr(op->args[0]);
        rhs << "(" << arg << ")";
    } else if (op->is_intrinsic(Call::copy_memory)) {
        internal_assert(op->args.size() == 3);
        string dest = print_expr(op->args[0]);
        string src = print_expr(op->args[1]);
        string size = print_expr(op->args[2]);
        rhs << "memcpy(" << dest << ", " << src << ", " << size << ")";
    } else if (op->is_intrinsic(Call::make_struct)) {
        // Emit a line something like:
        // struct {const int f_0, const char f_1, const int f_2} foo = {3, 'c', 4};

        // Get the args
        vector<string> values;
        for (size_t i = 0; i < op->args.size(); i++) {
            values.push_back(print_expr(op->args[i]));
        }
        do_indent();
        stream << "struct {";
        // List the types.
        for (size_t i = 0; i < op->args.size(); i++) {
            stream << "const " << print_type(op->args[i].type()) << " f_" << i << "; ";
        }
        string struct_name = unique_name('s');
        stream << "}  " << struct_name << " = {";
        // List the values.
        for (size_t i = 0; i < op->args.size(); i++) {
            if (i > 0) stream << ", ";
            stream << values[i];
        }
        stream << "};\n";
        // Return a pointer to it.
        rhs << "(&" << struct_name << ")";
    } else if (op->is_intrinsic(Call::stringify)) {
        // Rewrite to an snprintf into a fixed 1024-byte stack buffer.
        vector<string> printf_args;
        string format_string = "";
        for (size_t i = 0; i < op->args.size(); i++) {
            Type t = op->args[i].type();
            printf_args.push_back(print_expr(op->args[i]));
            if (t.is_int()) {
                format_string += "%lld";
                printf_args[i] = "(long long)(" + printf_args[i] + ")";
            } else if (t.is_uint()) {
                format_string += "%llu";
                printf_args[i] = "(long long unsigned)(" + printf_args[i] + ")";
            } else if (t.is_float()) {
                if (t.bits() == 32) {
                    format_string += "%f";
                } else {
                    format_string += "%e";
                }
            } else if (op->args[i].as<StringImm>()) {
                format_string += "%s";
            } else {
                internal_assert(t.is_handle());
                format_string += "%p";
            }

        }
        string buf_name = unique_name('b');
        do_indent();
        stream << "char " << buf_name << "[1024];\n";
        do_indent();
        stream << "snprintf(" << buf_name << ", 1024, \"" << format_string << "\"";
        for (size_t i = 0; i < printf_args.size(); i++) {
            stream << ", " << printf_args[i];
        }
        stream << ");\n";
        rhs << buf_name;

    } else if (op->is_intrinsic(Call::register_destructor)) {
        internal_assert(op->args.size() == 2);
        const StringImm *fn = op->args[0].as<StringImm>();
        internal_assert(fn);
        string arg = print_expr(op->args[1]);

        string call =
            fn->value + "(" +
            (have_user_context ? "__user_context_, " : "nullptr, ")
            + "arg);";

        do_indent();
        // Make a struct on the stack that calls the given function as a destructor
        string struct_name = unique_name('s');
        string instance_name = unique_name('d');
        stream << "struct " << struct_name << "{ "
               << "void *arg; "
               << struct_name << "(void *a) : arg((void *)a) {} "
               << "~" << struct_name << "() {" << call << "}"
               << "} " << instance_name << "(" << arg << ");\n";
        rhs << print_expr(0);
    } else if (op->call_type == Call::Intrinsic ||
               op->call_type == Call::PureIntrinsic) {
        // TODO: other intrinsics
        internal_error << "Unhandled intrinsic in C backend: " << op->name << '\n';

    } else {
        std::string name;
        if (op->call_type == Call::ExternCPlusPlus) {
            std::vector<std::string> namespaces;
            name = extract_namespaces(op->name, namespaces);
        } else {
            name = op->name;
        }

        // Generic calls
        vector<string> args(op->args.size());
        for (size_t i = 0; i < op->args.size(); i++) {
            args[i] = print_expr(op->args[i]);
        }
        rhs << name << "(";

        if (function_takes_user_context(op->name)) {
            rhs << (have_user_context ? "__user_context_, " : "nullptr, ");
        }

        for (size_t i = 0; i < op->args.size(); i++) {
            if (i > 0) rhs << ", ";
            rhs << args[i];
        }
        rhs << ")";
    }

    print_assignment(op->type, rhs.str());
}

void CodeGen_C::visit(const Load *op) {

    Type t = op->type;
    bool type_cast_needed =
        !allocations.contains(op->name) ||
        allocations.get(op->name).type != t;

    ostringstream rhs;
    if (type_cast_needed) {
        rhs << "(("
            << print_type(op->type)
            << " *)"
            << print_name(op->name)
            << ")";
    } else {
        rhs << print_name(op->name);
    }
    rhs << "["
        << print_expr(op->index)
        << "]";

    print_assignment(op->type, rhs.str());
}

// Emit a store of op->value into op->name[op->index].
//
// When the destination is not a known allocation of exactly the stored
// type (and always for handles), the base pointer is cast to a pointer to
// the stored type first. The index and value are printed before the
// statement is started, because print_expr may emit helper statements.
// The expression cache is cleared afterwards, since the store may change
// memory that cached expressions read.
void CodeGen_C::visit(const Store *op) {

    Type t = op->value.type();

    bool type_cast_needed =
        t.is_handle() ||
        !allocations.contains(op->name) ||
        allocations.get(op->name).type != t;

    string id_index = print_expr(op->index);
    string id_value = print_expr(op->value);
    do_indent();

    if (type_cast_needed) {
        // Only handles get a const-qualified pointee. For a handle the
        // element is itself a pointer (presumably print_type yields
        // "void *"), so "(const void * *)" still yields an assignable slot
        // that accepts pointers to const data. For any other type,
        // "(const T *)" would make the element read-only and the emitted
        // assignment would fail to compile.
        stream << "(("
               << (t.is_handle() ? "const " : "")
               << print_type(t)
               << " *)"
               << print_name(op->name)
               << ")";
    } else {
        stream << print_name(op->name);
    }
    stream << "["
           << id_index
           << "] = "
           << id_value
           << ";\n";

    cache.clear();
}

void CodeGen_C::visit(const Let *op) {
    string id_value = print_expr(op->value);
    Expr new_var = Variable::make(op->value.type(), id_value);
    Expr body = substitute(op->name, new_var, op->body);
    print_expr(body);
}

void CodeGen_C::visit(const Select *op) {
    ostringstream rhs;
    string true_val = print_expr(op->true_value);
    string false_val = print_expr(op->false_value);
    string cond = print_expr(op->condition);
    rhs << "(" << print_type(op->type) << ")"
        << "(" << cond
        << " ? " << true_val
        << " : " << false_val
        << ")";
    print_assignment(op->type, rhs.str());
}

void CodeGen_C::visit(const LetStmt *op) {
    string id_value = print_expr(op->value);
    Expr new_var = Variable::make(op->value.type(), id_value);
    Stmt body = substitute(op->name, new_var, op->body);
    body.accept(this);
}

void CodeGen_C::visit(const AssertStmt *op) {
    string id_cond = print_expr(op->condition);

    do_indent();
    // Halide asserts have different semantics to C asserts.  They're
    // supposed to clean up and make the containing function return
    // -1, so we can't use the C version of assert. Instead we convert
    // to an if statement.

    stream << "if (!" << id_cond << ") ";
    open_scope();
    string id_msg = print_expr(op->message);
    do_indent();
    stream << "return " << id_msg << ";\n";
    close_scope("");
}

void CodeGen_C::visit(const ProducerConsumer *op) {

    do_indent();
    stream << "// produce " << op->name << '\n';
    print_stmt(op->produce);

    if (op->update.defined()) {
        do_indent();
        stream << "// update " << op->name << '\n';
        print_stmt(op->update);
    }

    do_indent();
    stream << "// consume " << op->name << '\n';
    print_stmt(op->consume);
}

void CodeGen_C::visit(const For *op) {
    if (op->for_type == ForType::Parallel) {
        do_indent();
        stream << "#pragma omp parallel for\n";
    } else {
        internal_assert(op->for_type == ForType::Serial)
            << "Can only emit serial or parallel for loops to C\n";
    }

    string id_min = print_expr(op->min);
    string id_extent = print_expr(op->extent);

    do_indent();
    stream << "for (int "
           << print_name(op->name)
           << " = " << id_min
           << "; "
           << print_name(op->name)
           << " < " << id_min
           << " + " << id_extent
           << "; "
           << print_name(op->name)
           << "++)\n";

    open_scope();
    op->body.accept(this);
    close_scope("for " + print_name(op->name));

}

void CodeGen_C::visit(const Provide *) {
    // The C backend has no translation for Provide nodes, so reaching one is
    // an internal error. Presumably they are lowered away before codegen;
    // confirm against the lowering pipeline.
    internal_error << "Cannot emit Provide statements as C\n";
}

void CodeGen_C::visit(const Allocate *op) {
    open_scope();

    // For sizes less than 8k, do a stack allocation
    bool on_stack = false;
    int32_t constant_size;
    string size_id;
    if (op->new_expr.defined()) {
        Allocation alloc;
        alloc.type = op->type;
        alloc.free_function = op->free_function;
        allocations.push(op->name, alloc);
        heap_allocations.push(op->name, 0);
        stream << print_type(op->type) << "*" << print_name(op->name) << " = (" << print_expr(op->new_expr) << ");\n";
    } else {
        constant_size = op->constant_allocation_size();
        if (constant_size > 0) {
            int64_t stack_bytes = constant_size * op->type.bytes();

            if (stack_bytes > ((int64_t(1) << 31) - 1)) {
                user_error << "Total size for allocation "
                           << op->name << " is constant but exceeds 2^31 - 1.\n";
            } else {
                size_id = print_expr(Expr(static_cast<int32_t>(constant_size)));
                if (can_allocation_fit_on_stack(stack_bytes)) {
                    on_stack = true;
                }
            }
        } else {
            // Check that the allocation is not scalar (if it were scalar
            // it would have constant size).
            internal_assert(op->extents.size() > 0);

            size_id = print_assignment(Int(64), print_expr(op->extents[0]));

            for (size_t i = 1; i < op->extents.size(); i++) {
                // Make the code a little less cluttered for two-dimensional case
                string new_size_id_rhs;
                string next_extent = print_expr(op->extents[i]);
                if (i > 1) {
                    new_size_id_rhs =  "(" + size_id + " > ((int64_t(1) << 31) - 1)) ? " + size_id + " : (" + size_id + " * " + next_extent + ")";
                } else {
                    new_size_id_rhs = size_id + " * " + next_extent;
                }
                size_id = print_assignment(Int(64), new_size_id_rhs);
            }
            do_indent();
            stream << "if ((" << size_id << " > ((int64_t(1) << 31) - 1)) || ((" << size_id <<
              " * sizeof(" << print_type(op->type) << ")) > ((int64_t(1) << 31) - 1)))\n";
            open_scope();
            do_indent();
            stream << "halide_error("
                   << (have_user_context ? "__user_context_" : "nullptr")
                   << ", \"32-bit signed overflow computing size of allocation "
                   << op->name << "\\n\");\n";
            do_indent();
            stream << "return -1;\n";
            close_scope("overflow test " + op->name);
        }

        // Check the condition to see if this allocation should actually be created.
        // If the allocation is on the stack, the only condition we can respect is
        // unconditional false (otherwise a non-constant-sized array declaration
        // will be generated).
        if (!on_stack || is_zero(op->condition)) {
            Expr conditional_size = Select::make(op->condition,
                                                 Var(size_id),
                                                 Expr(static_cast<int32_t>(0)));
            conditional_size = simplify(conditional_size);
            size_id = print_assignment(Int(64), print_expr(conditional_size));
        }

        Allocation alloc;
        alloc.type = op->type;
        allocations.push(op->name, alloc);

        do_indent();
        stream << print_type(op->type) << ' ';

        if (on_stack) {
            stream << print_name(op->name)
                   << "[" << size_id << "];\n";
        } else {
            stream << "*"
                   << print_name(op->name)
                   << " = ("
                   << print_type(op->type)
                   << " *)halide_malloc("
                   << (have_user_context ? "__user_context_" : "nullptr")
                   << ", sizeof("
                   << print_type(op->type)
                   << ")*" << size_id << ");\n";
            heap_allocations.push(op->name, 0);
        }
    }

    op->body.accept(this);

    // Should have been freed internally
    internal_assert(!allocations.contains(op->name));

    close_scope("alloc " + print_name(op->name));
}

void CodeGen_C::visit(const Free *op) {
    if (heap_allocations.contains(op->name)) {
        string free_function = allocations.get(op->name).free_function;
        if (free_function.empty()) {
            free_function = "halide_free";
        }

        do_indent();
        stream << free_function << "("
               << (have_user_context ? "__user_context_, " : "nullptr, ")
               << print_name(op->name)
               << ");\n";
        heap_allocations.pop(op->name);
    }
    allocations.pop(op->name);
}

void CodeGen_C::visit(const Realize *) {
    // The C backend has no translation for Realize nodes, so reaching one is
    // an internal error. Presumably they are lowered away before codegen;
    // confirm against the lowering pipeline.
    internal_error << "Cannot emit realize statements to C\n";
}

void CodeGen_C::visit(const IfThenElse *op) {
    string cond_id = print_expr(op->condition);

    do_indent();
    stream << "if (" << cond_id << ")\n";
    open_scope();
    op->then_case.accept(this);
    close_scope("if " + cond_id);

    if (op->else_case.defined()) {
        do_indent();
        stream << "else\n";
        open_scope();
        op->else_case.accept(this);
        close_scope("if " + cond_id + " else");
    }
}

void CodeGen_C::visit(const Evaluate *op) {
    if (is_const(op->value)) return;
    string id = print_expr(op->value);
    do_indent();
    stream << "(void)" << id << ";\n";
}

void CodeGen_C::test() {
    Argument buffer_arg("buf", Argument::OutputBuffer, Int(32), 3);
    Argument float_arg("alpha", Argument::InputScalar, Float(32), 0);
    Argument int_arg("beta", Argument::InputScalar, Int(32), 0);
    Argument user_context_arg("__user_context", Argument::InputScalar, Handle(), 0);
    vector<Argument> args(4);
    args[0] = buffer_arg;
    args[1] = float_arg;
    args[2] = int_arg;
    args[3] = user_context_arg;
    Var x("x");
    Param<float> alpha("alpha");
    Param<int> beta("beta");
    Expr e = Select::make(alpha > 4.0f, print_when(x < 1, 3), 2);
    Stmt s = Store::make("buf", e, x, Parameter());
    s = LetStmt::make("x", beta+1, s);
    s = Block::make(s, Free::make("tmp.stack"));
    s = Allocate::make("tmp.stack", Int(32), {127}, const_true(), s);
    s = Block::make(s, Free::make("tmp.heap"));
    s = Allocate::make("tmp.heap", Int(32), {43, beta}, const_true(), s);

    Module m("", get_host_target());
    m.append(LoweredFunc("test1", args, s, LoweredFunc::External));

    ostringstream source;
    {
        CodeGen_C cg(source, CodeGen_C::CImplementation);
        cg.compile(m);
    }

    string src = source.str();
    string correct_source =
        headers +
        buffer_t_definition +
        "struct halide_filter_metadata_t;\n" +
        globals +
        "#ifndef HALIDE_FUNCTION_ATTRS\n"
        "#define HALIDE_FUNCTION_ATTRS\n"
        "#endif\n"
        "#ifdef __cplusplus\n"
        "extern \"C\" {\n"
        "#endif\n"
        "\n\n"
        "int test1(buffer_t *_buf_buffer, float _alpha, int32_t _beta, const void *__user_context) HALIDE_FUNCTION_ATTRS {\n"
        " int32_t *_buf = (int32_t *)(_buf_buffer->host);\n"
        " (void)_buf;\n"
        " const bool _buf_host_and_dev_are_null = (_buf_buffer->host == nullptr) && (_buf_buffer->dev == 0);\n"
        " (void)_buf_host_and_dev_are_null;\n"
        " const int32_t _buf_min_0 = _buf_buffer->min[0];\n"
        " (void)_buf_min_0;\n"
        " const int32_t _buf_min_1 = _buf_buffer->min[1];\n"
        " (void)_buf_min_1;\n"
        " const int32_t _buf_min_2 = _buf_buffer->min[2];\n"
        " (void)_buf_min_2;\n"
        " const int32_t _buf_min_3 = _buf_buffer->min[3];\n"
        " (void)_buf_min_3;\n"
        " const int32_t _buf_extent_0 = _buf_buffer->extent[0];\n"
        " (void)_buf_extent_0;\n"
        " const int32_t _buf_extent_1 = _buf_buffer->extent[1];\n"
        " (void)_buf_extent_1;\n"
        " const int32_t _buf_extent_2 = _buf_buffer->extent[2];\n"
        " (void)_buf_extent_2;\n"
        " const int32_t _buf_extent_3 = _buf_buffer->extent[3];\n"
        " (void)_buf_extent_3;\n"
        " const int32_t _buf_stride_0 = _buf_buffer->stride[0];\n"
        " (void)_buf_stride_0;\n"
        " const int32_t _buf_stride_1 = _buf_buffer->stride[1];\n"
        " (void)_buf_stride_1;\n"
        " const int32_t _buf_stride_2 = _buf_buffer->stride[2];\n"
        " (void)_buf_stride_2;\n"
        " const int32_t _buf_stride_3 = _buf_buffer->stride[3];\n"
        " (void)_buf_stride_3;\n"
        " const int32_t _buf_elem_size = _buf_buffer->elem_size;\n"
        " (void)_buf_elem_size;\n"
        " {\n"
        "  int64_t _0 = 43;\n"
        "  int64_t _1 = _0 * _beta;\n"
        "  if ((_1 > ((int64_t(1) << 31) - 1)) || ((_1 * sizeof(int32_t)) > ((int64_t(1) << 31) - 1)))\n"
        "  {\n"
        "   halide_error(__user_context_, \"32-bit signed overflow computing size of allocation tmp.heap\\n\");\n"
        "   return -1;\n"
        "  } // overflow test tmp.heap\n"
        "  int64_t _2 = _1;\n"
        "  int32_t *_tmp_heap = (int32_t *)halide_malloc(__user_context_, sizeof(int32_t)*_2);\n"
        "  {\n"
        "   int32_t _tmp_stack[127];\n"
        "   int32_t _3 = _beta + 1;\n"
        "   int32_t _4;\n"
        "   bool _5 = _3 < 1;\n"
        "   if (_5)\n"
        "   {\n"
        "    char b0[1024];\n"
        "    snprintf(b0, 1024, \"%lld%s\", (long long)(3), \"\\n\");\n"
        "    char const *_6 = b0;\n"
        "    int32_t _7 = halide_print(__user_context_, _6);\n"
        "    int32_t _8 = (_7, 3);\n"
        "    _4 = _8;\n"
        "   } // if _5\n"
        "   else\n"
        "   {\n"
        "    _4 = 3;\n"
        "   } // if _5 else\n"
        "   int32_t _9 = _4;\n"
        "   bool _10 = _alpha > float_from_bits(1082130432 /* 4 */);\n"
        "   int32_t _11 = (int32_t)(_10 ? _9 : 2);\n"
        "   _buf[_3] = _11;\n"
        "  } // alloc _tmp_stack\n"
        "  halide_free(__user_context_, _tmp_heap);\n"
        " } // alloc _tmp_heap\n"
        " return 0;\n"
        "}\n"
        "#ifdef __cplusplus\n"
        "}  // extern \"C\"\n"
        "#endif\n";
;
    if (src != correct_source) {
        int diff = 0;
        while (src[diff] == correct_source[diff]) diff++;
        int diff_end = diff + 1;
        while (diff > 0 && src[diff] != '\n') diff--;
        while (diff_end < (int)src.size() && src[diff_end] != '\n') diff_end++;

        internal_error
            << "Correct source code:\n" << correct_source
            << "Actual source code:\n" << src
            << "\nDifference starts at: " << src.substr(diff, diff_end - diff) << "\n";

    }


    std::cout << "CodeGen_C test passed\n";
}

}
}
back to top