https://github.com/halide/Halide
Raw File
Tip revision: 1aac1064793f721d83fc3d2c26e35830244d5c1c authored by aekul on 18 August 2020, 20:29:55 UTC
Add script and bilateral_grid
Tip revision: 1aac106
DebugToFile.cpp
#include <map>
#include <sstream>
#include <vector>

#include "DebugToFile.h"
#include "Function.h"
#include "IRMutator.h"
#include "IROperator.h"

namespace Halide {
namespace Internal {

using std::map;
using std::ostringstream;
using std::string;
using std::vector;

// IR mutator that instruments Realize nodes. For every Realize whose
// name maps (via env) to a Function with a non-empty debug_file(), the
// realization body is followed by a debug_to_file intrinsic call and
// an assertion that the call returned zero.
class DebugToFile : public IRMutator {
    // Functions of the pipeline, keyed by name. Held by reference; the
    // map must outlive this mutator.
    const map<string, Function> &env;

    using IRMutator::visit;

    Stmt visit(const Realize *op) override {
        auto iter = env.find(op->name);
        if (iter == env.end() || iter->second.debug_file().empty()) {
            // Nothing to dump for this realization; just recurse.
            return IRMutator::visit(op);
        }

        const Function &f = iter->second;

        user_assert(op->types.size() == 1)
            << "debug_to_file doesn't handle functions with multiple values yet\n";

        // Translate the element type into the integer code passed to
        // the debug_to_file intrinsic. UInt(1) shares the code used
        // for UInt(8).
        int type_code = 0;
        const Type t = op->types[0];
        if (t == Float(32)) {
            type_code = 0;
        } else if (t == Float(64)) {
            type_code = 1;
        } else if (t == UInt(8) || t == UInt(1)) {
            type_code = 2;
        } else if (t == Int(8)) {
            type_code = 3;
        } else if (t == UInt(16)) {
            type_code = 4;
        } else if (t == Int(16)) {
            type_code = 5;
        } else if (t == UInt(32)) {
            type_code = 6;
        } else if (t == Int(32)) {
            type_code = 7;
        } else if (t == UInt(64)) {
            type_code = 8;
        } else if (t == Int(64)) {
            type_code = 9;
        } else {
            user_error << "Type " << t << " not supported for debug_to_file\n";
        }

        // Intrinsic arguments, in order: file name, type code, and the
        // buffer variable of the function being dumped.
        vector<Expr> args;
        args.emplace_back(f.debug_file());
        args.emplace_back(type_code);
        args.push_back(Variable::make(Handle(), f.name() + ".buffer"));

        Expr call = Call::make(Int(32), Call::debug_to_file, args, Call::Intrinsic);

        // Bind the call result to a fresh name and assert it is zero;
        // otherwise report through halide_error_debug_to_file_failed.
        string call_result_name = unique_name("debug_to_file_result");
        Expr call_result_var = Variable::make(Int(32), call_result_name);
        Stmt check = AssertStmt::make(call_result_var == 0,
                                      Call::make(Int(32), "halide_error_debug_to_file_failed",
                                                 {f.name(), f.debug_file(), call_result_var},
                                                 Call::Extern));
        Stmt dump = LetStmt::make(call_result_name, call, check);

        // Run the (recursively instrumented) original body first, then
        // the dump, all inside the same realization.
        Stmt body = Block::make(mutate(op->body), dump);

        return Realize::make(op->name, op->types, op->memory_type, op->bounds, op->condition, body);
    }

public:
    explicit DebugToFile(const map<string, Function> &e)
        : env(e) {
    }
};

// IR mutator that deletes every Realize node whose name matches one of
// the given output functions, replacing the node with its (mutated)
// body. All other Realize nodes are visited normally.
class RemoveDummyRealizations : public IRMutator {
    // Pipeline outputs. Held by reference; the vector must outlive
    // this mutator.
    const vector<Function> &outputs;

    using IRMutator::visit;

    Stmt visit(const Realize *op) override {
        // Iterate by const reference to avoid copying each Function.
        for (const Function &f : outputs) {
            if (op->name == f.name()) {
                // Drop the Realize wrapper but keep processing what is
                // inside it.
                return mutate(op->body);
            }
        }
        return IRMutator::visit(op);
    }

public:
    explicit RemoveDummyRealizations(const vector<Function> &o)
        : outputs(o) {
    }
};

// IR mutator that wraps each ProducerConsumer node belonging to one of
// the given output functions in a Realize node. The Realize bounds are
// expressed with the symbolic variables "<name>.min.<d>" and
// "<name>.extent.<d>" for each dimension d of the output.
class AddDummyRealizations : public IRMutator {
    // Pipeline outputs. Held by reference; the vector must outlive
    // this mutator.
    const vector<Function> &outputs;

    using IRMutator::visit;

    Stmt visit(const ProducerConsumer *op) override {
        // Mutate children first so nested nodes are wrapped as well.
        Stmt s = IRMutator::visit(op);

        // Iterate by const reference to avoid copying each Function.
        for (const Function &out : outputs) {
            if (op->name != out.name()) {
                continue;
            }

            const int dims = out.dimensions();
            std::vector<Range> output_bounds;
            output_bounds.reserve(dims);
            for (int i = 0; i < dims; i++) {
                const string dim = std::to_string(i);
                Expr min = Variable::make(Int(32), out.name() + ".min." + dim);
                Expr extent = Variable::make(Int(32), out.name() + ".extent." + dim);
                output_bounds.emplace_back(min, extent);
            }

            return Realize::make(out.name(),
                                 out.output_types(),
                                 MemoryType::Auto,
                                 output_bounds,
                                 const_true(),
                                 s);
        }
        return s;
    }

public:
    explicit AddDummyRealizations(const vector<Function> &o)
        : outputs(o) {
    }
};

// Entry point of the pass: returns s with debug_to_file calls added
// for every function in env that has a debug file set. Runs three
// mutators in sequence over the statement.
Stmt debug_to_file(Stmt s, const vector<Function> &outputs, const map<string, Function> &env) {
    AddDummyRealizations add_wrappers(outputs);
    DebugToFile inject_dumps(env);
    RemoveDummyRealizations strip_wrappers(outputs);

    // Step 1: give the produce nodes of the output functions a
    // temporary enclosing realize node, so the next step knows where
    // their debug output should be written.
    Stmt wrapped = add_wrappers.mutate(s);

    // Step 2: append the file-dumping call inside each relevant
    // realize node.
    Stmt instrumented = inject_dumps.mutate(wrapped);

    // Step 3: take the temporary realize nodes around the outputs
    // back out again.
    return strip_wrappers.mutate(instrumented);
}

}  // namespace Internal
}  // namespace Halide
back to top