https://github.com/halide/Halide
Tip revision: 9903d2bf77a7f88d50ddb68668b9e462201a5707 authored by Andrew Adams on 27 August 2020, 02:14:59 UTC
Add comment explaining why we don't do dynamic tracking when no upper bound too
Add comment explaining why we don't do dynamic tracking when no upper bound too
Tip revision: 9903d2b
DebugToFile.cpp
#include <map>
#include <sstream>
#include <vector>
#include "DebugToFile.h"
#include "Function.h"
#include "IRMutator.h"
#include "IROperator.h"
namespace Halide {
namespace Internal {
using std::map;
using std::ostringstream;
using std::string;
using std::vector;
class DebugToFile : public IRMutator {
const map<string, Function> &env;
using IRMutator::visit;
Stmt visit(const Realize *op) override {
map<string, Function>::const_iterator iter = env.find(op->name);
if (iter != env.end() && !iter->second.debug_file().empty()) {
Function f = iter->second;
vector<Expr> args;
user_assert(op->types.size() == 1)
<< "debug_to_file doesn't handle functions with multiple values yet\n";
// The name of the file
args.emplace_back(f.debug_file());
// Inject loads to the corners of the function so that any
// passes doing further analysis of buffer use understand
// what we're doing (e.g. so we trigger a copy-back from a
// device pointer).
Expr num_elements = 1;
for (size_t i = 0; i < op->bounds.size(); i++) {
num_elements *= op->bounds[i].extent;
}
int type_code = 0;
Type t = op->types[0];
if (t == Float(32)) {
type_code = 0;
} else if (t == Float(64)) {
type_code = 1;
} else if (t == UInt(8) || t == UInt(1)) {
type_code = 2;
} else if (t == Int(8)) {
type_code = 3;
} else if (t == UInt(16)) {
type_code = 4;
} else if (t == Int(16)) {
type_code = 5;
} else if (t == UInt(32)) {
type_code = 6;
} else if (t == Int(32)) {
type_code = 7;
} else if (t == UInt(64)) {
type_code = 8;
} else if (t == Int(64)) {
type_code = 9;
} else {
user_error << "Type " << t << " not supported for debug_to_file\n";
}
args.emplace_back(type_code);
Expr buf = Variable::make(Handle(), f.name() + ".buffer");
args.push_back(buf);
Expr call = Call::make(Int(32), Call::debug_to_file, args, Call::Intrinsic);
string call_result_name = unique_name("debug_to_file_result");
Expr call_result_var = Variable::make(Int(32), call_result_name);
Stmt body = AssertStmt::make(call_result_var == 0,
Call::make(Int(32), "halide_error_debug_to_file_failed",
{f.name(), f.debug_file(), call_result_var},
Call::Extern));
body = LetStmt::make(call_result_name, call, body);
body = Block::make(mutate(op->body), body);
return Realize::make(op->name, op->types, op->memory_type, op->bounds, op->condition, body);
} else {
return IRMutator::visit(op);
}
}
public:
DebugToFile(const map<string, Function> &e)
: env(e) {
}
};
// Mutator that removes Realize nodes named after one of the given output
// functions, replacing each such node with its mutated body. Realize nodes
// with any other name go through the default IRMutator visit.
class RemoveDummyRealizations : public IRMutator {
    // Output functions whose Realize nodes should be stripped.
    const vector<Function> &outputs;

    using IRMutator::visit;

    Stmt visit(const Realize *op) override {
        // Iterate by const reference so no Function is copied per step.
        for (const Function &f : outputs) {
            if (op->name == f.name()) {
                // Drop the Realize wrapper but keep processing its body.
                return mutate(op->body);
            }
        }
        return IRMutator::visit(op);
    }

public:
    // `o` is held by reference and must outlive this mutator.
    explicit RemoveDummyRealizations(const vector<Function> &o)
        : outputs(o) {
    }
};
// Mutator that wraps the ProducerConsumer node of each given output
// function in a Realize node. The Realize bounds are built from the
// "<name>.min.<i>" and "<name>.extent.<i>" variables, one pair per
// dimension. ProducerConsumer nodes with other names are returned as
// mutated by the default IRMutator visit.
class AddDummyRealizations : public IRMutator {
    // Output functions whose ProducerConsumer nodes get a Realize wrapper.
    const vector<Function> &outputs;

    using IRMutator::visit;

    Stmt visit(const ProducerConsumer *op) override {
        // Mutate children first; the wrapper goes around the result.
        Stmt s = IRMutator::visit(op);

        // Iterate by const reference so no Function is copied per step.
        for (const Function &out : outputs) {
            if (op->name != out.name()) {
                continue;
            }

            std::vector<Range> output_bounds;
            for (int i = 0; i < out.dimensions(); i++) {
                string dim = std::to_string(i);
                Expr min = Variable::make(Int(32), out.name() + ".min." + dim);
                Expr extent = Variable::make(Int(32), out.name() + ".extent." + dim);
                output_bounds.emplace_back(min, extent);
            }

            // The first output with a matching name decides the wrapper.
            return Realize::make(out.name(),
                                 out.output_types(),
                                 MemoryType::Auto,
                                 output_bounds,
                                 const_true(),
                                 s);
        }
        return s;
    }

public:
    // `o` is held by reference and must outlive this mutator.
    explicit AddDummyRealizations(const vector<Function> &o)
        : outputs(o) {
    }
};
// Apply the debug_to_file instrumentation to a statement in three passes:
// wrap the output functions' produce nodes in temporary Realize nodes,
// inject the debug-to-file calls at Realize nodes, then strip the
// temporary Realize nodes again.
Stmt debug_to_file(Stmt s, const vector<Function> &outputs, const map<string, Function> &env) {
    AddDummyRealizations wrap_outputs(outputs);
    DebugToFile inject_calls(env);
    RemoveDummyRealizations unwrap_outputs(outputs);

    // Outputs have no Realize node of their own, so give them one for the
    // duration of the injection pass; that is how the pass knows when to
    // write the debug outputs.
    Stmt wrapped = wrap_outputs.mutate(s);

    Stmt instrumented = inject_calls.mutate(wrapped);

    // The wrappers around the outputs were only scaffolding; remove them.
    return unwrap_outputs.mutate(instrumented);
}
} // namespace Internal
} // namespace Halide