Revision bc30d6f83513724e9f832098d9e8d5fb87ebc9fc authored by Steven Johnson on 03 August 2023, 00:18:33 UTC, committed by GitHub on 03 August 2023, 00:18:33 UTC
* Revise labels on autoscheduler tests This is step 1 in fixing https://github.com/halide/Halide/issues/7731: it replaces the `autoschedulers` tag with more granular ones, so that we can modify the build script to test the right autoscheduler(s) for a given backend. (Note that the `autoschedulers` tag was unused by the buildbots, which only used the generic `auto_schedule` tag.) Step 2 will be to modify the buildbot script after this lands to use the new tags above. Step 3 will be to remove the `auto_schedule` tag. * Fix anderson2021 labels
1 parent 734df3f
StmtToViz.cpp
#include "StmtToViz.h"
#include "Debug.h"
#include "Error.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "IRVisitor.h"
#include "Module.h"
#include "Scope.h"
#include "Substitute.h"
#include "Util.h"
#include <cstdio>
#include <fstream>
#include <iostream>
#include <iterator>
#include <regex>
#include <sstream>
#include <utility>
namespace Halide {
namespace Internal {
extern "C" unsigned char halide_html_template_StmtToViz_dependencies[];
extern "C" unsigned char halide_html_template_StmtToViz_stylesheet[];
extern "C" unsigned char halide_html_template_StmtToViz_javascript[];
// Classes defined within this file
class CostModel;
class AssemblyInfo;
template<typename T>
class HTMLCodePrinter;
class HTMLVisualizationPrinter;
class IRVisualizer;
/** IRCostModel
* A basic cost model for Halide IR. Estimates computation
* cost through simple op-counting and data-movement cost
* by counting the number of bits being moved.
*/
class IRCostModel : public IRVisitor {
public:
IRCostModel()
= default;
// Pre-compute all costs to avoid repeated work
void compute_all_costs(const Module &m) {
// Compute all node costs
for (const auto &fn : m.functions()) {
fn.body.accept(this);
}
// Compute the max cost for each category
max_compute_cost = -1;
for (auto const &entry : compute_cost) {
max_compute_cost = std::max(entry.second, max_compute_cost);
}
max_data_cost = -1;
for (auto const &entry : data_cost) {
max_data_cost = std::max(entry.second, max_data_cost);
}
max_compute_cost_inclusive = -1;
for (auto const &entry : compute_cost_inclusive) {
max_compute_cost_inclusive = std::max(entry.second, max_compute_cost_inclusive);
}
max_data_cost_inclusive = -1;
for (auto const &entry : data_cost_inclusive) {
max_data_cost_inclusive = std::max(entry.second, max_data_cost_inclusive);
}
}
// Returns the compute cost of a node (estimated using simple op-counting)
int get_compute_cost(const IRNode *node, bool include_subtree_cost) {
internal_assert(node != nullptr) << "IRCostModel::get_compute_cost(): node is nullptr\n";
int cost = -1;
if (compute_cost.count(node)) {
cost = include_subtree_cost ? compute_cost_inclusive[node] : compute_cost[node];
} else {
internal_error << "IRCostModel::get_compute_cost(): cost lookup failed\n";
}
internal_assert(cost >= 0) << "Cost must not be negative.\n";
return cost;
}
// Returns the data movement cost of a node (the number of bits moved in load/store/shuffle ops)
int get_data_movement_cost(const IRNode *node, bool include_subtree_cost) {
internal_assert(node != nullptr) << "IRCostModel::get_data_movement_cost(): node is nullptr\n";
int cost = -1;
if (compute_cost.count(node)) {
cost = include_subtree_cost ? data_cost_inclusive[node] : data_cost[node];
} else {
internal_error << "IRCostModel::get_data_movement_cost(): cost lookup failed\n";
}
internal_assert(cost >= 0) << "Cost cost must not be negative.\n";
return cost;
}
// Returns the max compute cost of any node in the program
int get_max_compute_cost(bool include_subtree_cost) {
return include_subtree_cost ? max_compute_cost_inclusive : max_compute_cost;
}
// Returns the max data movement cost of any node in the program
int get_max_data_movement_cost(bool include_subtree_cost) {
return include_subtree_cost ? max_data_cost_inclusive : max_data_cost;
}
private:
// Cost database. We track two costs:
// - The line cost of a node is the sum of the node cost
// plus the cost of any children that are printed on
// a single line (since we display cost by each line in
// the program)
// - The inclusive cost is the cost of the entire sub-tree.
// We display this cost when the user collapses a program
// block in the IR.
std::map<const IRNode *, int> compute_cost;
std::map<const IRNode *, int> data_cost;
std::map<const IRNode *, int> compute_cost_inclusive;
std::map<const IRNode *, int> data_cost_inclusive;
// We also track the max costs to determine the cost color
// intensity for a given line of code
int max_compute_cost = -1;
int max_data_cost = -1;
int max_compute_cost_inclusive = -1;
int max_data_cost_inclusive = -1;
/* Utility functions to store node costs in the cost database */
void set_compute_costs(const IRNode *node, int node_cost, const std::vector<const IRNode *> &child_nodes) {
set_compute_costs(node, node_cost, child_nodes, child_nodes);
}
void set_compute_costs(const IRNode *node, int node_cost, const std::vector<const IRNode *> &child_nodes, const std::vector<const IRNode *> &inline_child_nodes) {
int subtree_cost = 0;
for (const IRNode *child_node : child_nodes) {
// Certain child nodes can be null. Ex: else-case
// in an if statement
if (child_node) {
subtree_cost += get_compute_cost(child_node, true);
}
}
int line_cost = node_cost;
for (const IRNode *child_node : inline_child_nodes) {
if (child_node) {
line_cost += get_compute_cost(child_node, true);
}
}
compute_cost[node] = line_cost;
compute_cost_inclusive[node] = node_cost + subtree_cost;
}
void set_data_costs(const IRNode *node, int node_cost, const std::vector<const IRNode *> &child_nodes) {
set_data_costs(node, node_cost, child_nodes, child_nodes);
}
void set_data_costs(const IRNode *node, int node_cost, const std::vector<const IRNode *> &child_nodes, const std::vector<const IRNode *> &inline_child_nodes) {
int subtree_cost = 0;
for (const IRNode *child_node : child_nodes) {
// Certain child nodes can be null. Ex: else-case
// in an if statement
if (child_node) {
subtree_cost += get_data_movement_cost(child_node, true);
}
}
int line_cost = node_cost;
for (const IRNode *child_node : inline_child_nodes) {
if (child_node) {
line_cost += get_data_movement_cost(child_node, true);
}
}
data_cost[node] = line_cost;
data_cost_inclusive[node] = node_cost + subtree_cost;
}
/* Visitor functions for each IR node */
using IRVisitor::visit;
void visit(const IntImm *op) override {
set_compute_costs(op, 0, {});
set_data_costs(op, 0, {});
}
void visit(const UIntImm *op) override {
set_compute_costs(op, 0, {});
set_data_costs(op, 0, {});
}
void visit(const FloatImm *op) override {
set_compute_costs(op, 0, {});
set_data_costs(op, 0, {});
}
void visit(const StringImm *op) override {
set_compute_costs(op, 0, {});
set_data_costs(op, 0, {});
}
void visit(const Variable *op) override {
set_compute_costs(op, 0, {});
set_data_costs(op, 0, {});
}
void visit(const Cast *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->value.get()});
set_data_costs(op, 0, {op->value.get()});
}
void visit(const Reinterpret *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->value.get()});
set_data_costs(op, 0, {op->value.get()});
}
void visit(const Add *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Sub *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Mul *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Div *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Mod *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Min *op) override {
IRVisitor::visit(op);
// This cost model treats min as a single op
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Max *op) override {
IRVisitor::visit(op);
// This cost model treats max as a single op
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const EQ *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const NE *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const LT *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const LE *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const GT *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const GE *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const And *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Or *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get(), op->b.get()});
set_data_costs(op, 0, {op->a.get(), op->b.get()});
}
void visit(const Not *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->a.get()});
set_data_costs(op, 0, {op->a.get()});
}
void visit(const Select *op) override {
IRVisitor::visit(op);
set_compute_costs(op, op->type.lanes(), {op->condition.get(), op->true_value.get(), op->false_value.get()});
set_data_costs(op, 0, {op->condition.get(), op->true_value.get(), op->false_value.get()});
}
void visit(const Load *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->predicate.get(), op->index.get()});
set_data_costs(op, op->type.bits() * op->type.lanes(), {op->predicate.get(), op->index.get()});
}
void visit(const Ramp *op) override {
// The cost of a Ramp is higher when the stride is not 1,
// but currently the cost model does not consider such cases
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->base.get(), op->stride.get()});
set_data_costs(op, 0, {op->base.get(), op->stride.get()});
}
void visit(const Broadcast *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 1, {op->value.get()});
set_data_costs(op, 0, {op->value.get()});
}
void visit(const Call *op) override {
IRVisitor::visit(op);
std::vector<const IRNode *> args;
for (const auto &arg : op->args) {
args.push_back(arg.get());
}
set_compute_costs(op, 1, args);
// Currently there is no special handling
// for intrinsics such as `prefetch`
set_data_costs(op, 0, args);
}
void visit(const Let *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->value.get(), op->body.get()});
set_data_costs(op, 0, {op->value.get(), op->body.get()});
}
void visit(const Shuffle *op) override {
IRVisitor::visit(op);
std::vector<const IRNode *> args;
for (const auto &arg : op->vectors) {
args.push_back(arg.get());
}
set_compute_costs(op, 0, args);
set_data_costs(op, op->type.bits() * op->type.lanes(), args);
}
void visit(const VectorReduce *op) override {
IRVisitor::visit(op);
const int factor = op->value.type().lanes() / op->type.lanes();
const int op_count = op->type.lanes() * (factor - 1);
set_compute_costs(op, op_count, {op->value.get()});
set_data_costs(op, 0, {op->value.get()});
}
void visit(const LetStmt *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->value.get(), op->body.get()}, {op->value.get()});
set_data_costs(op, 0, {op->value.get(), op->body.get()}, {op->value.get()});
}
void visit(const AssertStmt *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 1, {op->condition.get()});
set_data_costs(op, 0, {op->condition.get()});
}
void visit(const ProducerConsumer *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->body.get()}, {});
set_data_costs(op, 0, {op->body.get()}, {});
}
void visit(const For *op) override {
// The cost of a loop-node essentially depends on its iteration
// count. The cost model currently ignores such costs.
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->min.get(), op->extent.get(), op->body.get()}, {op->min.get(), op->extent.get()});
set_data_costs(op, 0, {op->min.get(), op->extent.get(), op->body.get()}, {op->min.get(), op->extent.get()});
}
void visit(const Acquire *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 1, {op->semaphore.get(), op->count.get(), op->body.get()}, {op->semaphore.get(), op->count.get()});
set_data_costs(op, 0, {op->semaphore.get(), op->count.get(), op->body.get()}, {op->semaphore.get(), op->count.get()});
}
void visit(const Store *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->predicate.get(), op->value.get(), op->index.get()});
set_data_costs(op, op->value.type().bits() * op->value.type().lanes(), {op->predicate.get(), op->value.get(), op->index.get()});
}
void visit(const Provide *op) override {
IRVisitor::visit(op);
std::vector<const IRNode *> args;
for (const auto &arg : op->values) {
args.push_back(arg.get());
}
for (const auto &arg : op->args) {
args.push_back(arg.get());
}
args.push_back(op->predicate.get());
set_compute_costs(op, 0, args, {});
set_data_costs(op, 0, args, {});
}
void visit(const Allocate *op) override {
// We do not model allocation/de-allocation costs
IRVisitor::visit(op);
std::vector<const IRNode *> args_inline;
for (const auto &arg : op->extents) {
args_inline.push_back(arg.get());
}
args_inline.push_back(op->new_expr.get());
std::vector<const IRNode *> args = args_inline;
args.push_back(op->body.get());
set_compute_costs(op, 0, args, args_inline);
set_data_costs(op, 0, args, args_inline);
}
void visit(const Free *op) override {
// We do not model allocation/de-allocation costs
IRVisitor::visit(op);
set_compute_costs(op, 0, {});
set_data_costs(op, 0, {});
}
void visit(const Realize *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->condition.get(), op->body.get()}, {op->condition.get()});
set_data_costs(op, 0, {op->condition.get(), op->body.get()}, {op->condition.get()});
}
void visit(const Prefetch *op) override {
IRVisitor::visit(op);
std::vector<const IRNode *> args_inline;
for (const auto &arg : op->bounds) {
args_inline.push_back(arg.min.get());
}
args_inline.push_back(op->condition.get());
std::vector<const IRNode *> args = args_inline;
args.push_back(op->body.get());
set_compute_costs(op, 0, args);
int elem_size = 0;
for (const auto &etype : op->types) {
elem_size += etype.bits() * etype.lanes();
}
set_data_costs(op, elem_size, args, args_inline);
}
void visit(const Block *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->first.get(), op->rest.get()}, {});
set_data_costs(op, 0, {op->first.get(), op->rest.get()}, {});
}
void visit(const Fork *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->first.get(), op->rest.get()}, {});
set_data_costs(op, 0, {op->first.get(), op->rest.get()}, {});
}
void visit(const IfThenElse *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 1, {op->condition.get(), op->then_case.get()}, {op->condition.get()});
set_data_costs(op, 0, {op->condition.get(), op->then_case.get()}, {op->condition.get()});
}
void visit(const Evaluate *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->value.get()});
set_data_costs(op, 0, {op->value.get()});
}
void visit(const Atomic *op) override {
IRVisitor::visit(op);
set_compute_costs(op, 0, {op->body.get()}, {});
set_data_costs(op, 0, {op->body.get()}, {});
}
};
/** GetAssemblyInfo
* Used to map some Halide IR nodes to line-numbers in the
* assembly file containing the corresponding generated code.
*/
class AssemblyInfo : public IRVisitor {
public:
AssemblyInfo() = default;
void generate(const std::string &code, const Module &m) {
// Traverse the module to populate the list of
// nodes we need to map and generate their assembly
// markers (comments that appear in the assembly code
// associating the code with this node)
for (const auto &fn : m.functions()) {
fn.body.accept(this);
}
// Find markers in asm code
std::istringstream asm_stream(code);
std::string line;
int lno = 1;
while (getline(asm_stream, line)) {
// Try all markers
std::vector<uint64_t> matched_nodes;
for (auto const &[node, regex] : markers) {
if (std::regex_search(line, regex)) {
// Save line number
lnos[node] = lno;
// Save this node's id
matched_nodes.push_back(node);
}
}
// We map to the first match, stop
// checking matched nodes
for (auto const &node : matched_nodes) {
markers.erase(node);
}
lno++;
}
}
int get_asm_lno(uint64_t node_id) {
if (lnos.count(node_id)) {
return lnos[node_id];
}
return -1;
}
private:
// Generate asm markers for Halide loops
int loop_id = 0;
int gen_loop_id() {
return ++loop_id;
}
std::regex gen_loop_asm_marker(int id, const std::string &loop_var) {
std::regex dollar("\\$");
std::string marker = "%\"" + std::to_string(id) + "_for_" + loop_var;
marker = std::regex_replace(marker, dollar, "\\$");
return std::regex(marker);
}
// Generate asm markers for Halide producer/consumer ndoes
int prodcons_id = 0;
int gen_prodcons_id() {
return ++prodcons_id;
}
std::regex gen_prodcons_asm_marker(int id, const std::string &var, bool is_producer) {
std::regex dollar("\\$");
std::string marker = "%\"" + std::to_string(id) + (is_producer ? "_produce_" : "_consume_") + var;
marker = std::regex_replace(marker, dollar, "\\$");
return std::regex(marker);
}
// Mapping of IR nodes to their asm markers
std::map<uint64_t, std::regex> markers;
// Mapping of IR nodes to their asm line numbers
std::map<uint64_t, int> lnos;
using IRVisitor::visit;
void visit(const ProducerConsumer *op) override {
// Generate asm marker
markers[(uint64_t)op] = gen_prodcons_asm_marker(gen_prodcons_id(), op->name, op->is_producer);
// Continue traversal
IRVisitor::visit(op);
}
void visit(const For *op) override {
// Generate asm marker
markers[(uint64_t)op] = gen_loop_asm_marker(gen_loop_id(), op->name);
// Continue traversal
IRVisitor::visit(op);
}
};
/** HTMLCodePrinter
* Prints IR code in HTML. Very similar to generating a stmt
* file, except that the generated html is more interactive.
*/
template<typename T>
class HTMLCodePrinter : public IRVisitor {
public:
HTMLCodePrinter(T &os, std::map<const IRNode *, int> &nids)
: stream(os), node_ids(nids), context_stack(1, 0) {
}
// Make class non-copyable and non-moveable
HTMLCodePrinter(const HTMLCodePrinter &) = delete;
HTMLCodePrinter &operator=(const HTMLCodePrinter &) = delete;
void init_cost_info(IRCostModel cm) {
cost_model = std::move(cm);
}
void print(const Module &m) {
// Generate a unique ID for this module
int id = gen_unique_id();
// Enter new scope for this module
scope.push(m.name(), id);
// The implementation doesn't need to support submodules:
// we only call this for Modules that have already had their submodules
// resolved.
internal_assert(m.submodules().empty()) << "StmtToViz does not support submodules.";
// Open div to hold this module
print_opening_tag("div", "Module");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "module");
print_text(" name=" + m.name() + ", target=" + m.target().to_string());
print_closing_tag("span");
print_toggle_anchor_closing_tag();
// Open code block to hold module body
print_html_element("span", "matched", " {");
// Open indented div to hold body code
print_opening_tag("div", "indent ModuleBody", id);
// Print module buffers
for (const auto &buf : m.buffers()) {
print(buf);
}
// Print module functions
for (const auto &fn : m.functions()) {
print(fn);
}
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding module body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this module
print_closing_tag("div");
// Pop out to outer scope
scope.pop(m.name());
}
private:
// Handle to output file stream
T &stream;
// Used to generate unique ids
int id = 0;
std::map<const IRNode *, int> &node_ids;
// Used to track scope during IR traversal
Scope<int> scope;
/* All spans and divs will have an id of the form "x-y", where x
* is shared among all spans/divs in the same context, and y is unique.
* These variables are used to track the context within generated HTML */
std::vector<int> context_stack;
std::vector<std::string> context_stack_tags;
// Holds cost information for visualized program
IRCostModel cost_model;
/* Private print functions to handle various IR types */
void print(const Buffer<> &buf) {
// Generate a unique ID for this module
int id = gen_unique_id();
// Determine whether to print buffer data
bool print_data = ends_with(buf.name(), "_gpu_source_kernels");
// Open div to hold this buffer
print_opening_tag("div", "Buffer");
if (print_data) {
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_html_element("span", "keyword", "buffer ");
print_variable(buf.name());
print_toggle_anchor_closing_tag();
// Print data
print_text(" = ");
// Open code block to hold module body
print_html_element("span", "matched", " {");
// Open indented div to hold buffer data
print_opening_tag("div", "indent BufferData", id);
std::string str((const char *)buf.data(), buf.size_in_bytes());
if (starts_with(buf.name(), "cuda_")) {
print_cuda_gpu_source_kernels(str);
} else {
stream << "<pre>\n"
<< str << "</pre>\n";
}
print_closing_tag("div");
// Close code block holding buffer body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), " }");
} else {
// Print buffer name and move on
print_html_element("span", "keyword", "buffer ");
print_variable(buf.name());
}
// Close div holding this buffer
print_closing_tag("div");
}
void print(const LoweredFunc &fn) {
// Generate a unique ID for this function
int id = gen_unique_id();
// Enter new scope for this function
scope.push(fn.name, id);
// Open div to hold this buffer
print_opening_tag("div", "Function");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text (fn name and args)
// Note: We wrap the show/hide buttons in a navigation anchor
// that lets us sync text and visualization tabs.
print_opening_tag("span", "matched");
print_html_element("span", "keyword nav-anchor", "func ", "lowered-func-" + fn.name);
print_text(fn.name + "(");
print_closing_tag("span");
print_fndecl_args(fn.args);
print_html_element("span", "matched", ")");
print_toggle_anchor_closing_tag();
// Add a button to jump to this function in the viz
print_visualization_button("lowered-func-viz-" + std::to_string(id));
// Open code block to hold function body
print_html_element("span", "matched", "{");
// Open indented div to hold body code
print_opening_tag("div", "indent FunctionBody", id);
// Print function body
print(fn.body);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding func body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this function
print_closing_tag("div");
// Pop out to outer scope
scope.pop(fn.name);
}
void print(const Expr &ir) {
ir.accept(this);
}
void print(const Stmt &ir) {
ir.accept(this);
}
/* Methods used to emit common HTML patterns */
// Prints the opening tag for the specified html element. A unique ID is
// auto-generated unless provided.
void print_opening_tag(const std::string &tag, const std::string &cls, int id = -1) {
stream << "<" << tag << " class='" << cls << "' id='";
if (id == -1) {
stream << context_stack.back() << "-";
stream << gen_unique_id();
} else {
stream << id;
}
stream << "'>";
context_stack.push_back(gen_unique_id());
context_stack_tags.push_back(tag);
}
void print_opening_tag(const std::string &tag, const std::string &cls, std::string id) {
stream << "<" << tag << " class='" << cls << "' id='" << id << "'>";
context_stack.push_back(gen_unique_id());
context_stack_tags.push_back(tag);
}
// Prints the closing tag for the specified html element.
void print_closing_tag(const std::string &tag) {
internal_assert(!context_stack.empty() && tag == context_stack_tags.back())
<< tag << " " << context_stack.empty() << " " << context_stack_tags.back();
context_stack.pop_back();
context_stack_tags.pop_back();
stream << "</" + tag + ">";
}
// Prints an html element: opening tag, body and closing tag
void print_html_element(const std::string &tag, const std::string &cls, const std::string &body, int id = -1) {
print_opening_tag(tag, cls, id);
stream << body;
print_closing_tag(tag);
}
void print_html_element(const std::string &tag, const std::string &cls, const std::string &body, std::string id) {
print_opening_tag(tag, cls, id);
stream << body;
print_closing_tag(tag);
}
// Prints the opening/closing tags for an anchor that toggles code block view
void print_toggle_anchor_opening_tag(int id) {
stream << "<a onclick='return toggle(" << id << ");' href=_blank>";
}
void print_toggle_anchor_closing_tag() {
stream << "</a>";
}
// Prints a variable to stream
void print_variable(const std::string &x) {
stream << variable(x);
}
std::string variable(const std::string &x) {
int id;
if (scope.contains(x)) {
id = scope.get(x);
} else {
id = gen_unique_id();
scope.push(x, id);
}
std::ostringstream s;
s << "<b class='variable matched' id='" << id << "-" << gen_unique_id() << "'>";
s << x;
s << "</b>";
return s.str();
}
// Prints text to stream
void print_text(const std::string &x) {
stream << x;
}
// Prints the button to show or hide a code scope
void print_show_hide_icon(int id) {
stream << "<div class='show-hide-btn-wrapper'>"
<< " <div class='show-hide-btn' style='display:none;' id=" << id << "-show>"
<< " <i class='bi bi-plus-square' title='Expand code block'></i>"
<< " </div>"
<< " <div class='show-hide-btn' id=" << id << "-hide>"
<< " <i class='bi bi-dash-square' title='Collapse code block'></i>"
<< " </div>"
<< "</div>";
}
// Prints a button to sync text with visualization
void print_visualization_button(std::string id) {
stream << "<button class='icon-btn sync-btn' onclick='scrollToViz(\"" << id << "\")'>"
<< " <i class='bi bi-arrow-right-square' title='Jump to visualization'></i>"
<< "</button>";
}
// CUDA kernels are embedded into modules as PTX assembly. This
// routine pretty - prints that assembly format.
void print_cuda_gpu_source_kernels(const std::string &str) {
print_opening_tag("code", "ptx");
int current_id = -1;
bool in_braces = false;
bool in_func_signature = false;
std::string current_kernel;
std::istringstream ss(str);
for (std::string line; std::getline(ss, line);) {
if (line.empty()) {
stream << "\n";
continue;
}
line = replace_all(line, "&", "&");
line = replace_all(line, "<", "<");
line = replace_all(line, ">", ">");
line = replace_all(line, "\"", """);
line = replace_all(line, "/", "/");
line = replace_all(line, "'", "'");
if (starts_with(line, ".visible .entry")) {
std::vector<std::string> parts = split_string(line, " ");
if (parts.size() == 3) {
in_func_signature = true;
current_id = gen_unique_id();
print_toggle_anchor_opening_tag(current_id);
print_show_hide_icon(current_id);
std::string kernel_name = parts[2].substr(0, parts[2].length() - 1);
line = "<span class='keyword'>.visible</span> <span class='keyword'>.entry</span> ";
line += variable(kernel_name) + " <span class='matched'>(</span>";
current_kernel = kernel_name;
}
} else if (starts_with(line, ")") && in_func_signature) {
print_toggle_anchor_closing_tag();
in_func_signature = false;
line = "<span class='matched'>)</span>" + line.substr(1);
} else if (starts_with(line, "{") && !in_braces) {
in_braces = true;
print_toggle_anchor_closing_tag();
print_html_element("span", "matched", "{");
internal_assert(current_id != -1);
print_opening_tag("div", "indent", current_id);
current_id = -1;
line = line.substr(1);
scope.push(current_kernel, gen_unique_id());
} else if (starts_with(line, "}") && in_braces) {
print_closing_tag("div");
line = "<span class='matched'>}</span>" + line.substr(1);
in_braces = false;
scope.pop(current_kernel);
}
bool indent = false;
if (line[0] == '\t') {
// Replace first tab with four spaces.
line = line.substr(1);
indent = true;
}
line = replace_all(line, ".f32", ".<span class='OpF32'>f32</span>");
line = replace_all(line, ".f64", ".<span class='OpF64'>f64</span>");
line = replace_all(line, ".s8", ".<span class='OpI8'>s8</span>");
line = replace_all(line, ".s16", ".<span class='OpI16'>s16</span>");
line = replace_all(line, ".s32", ".<span class='OpI32'>s32</span>");
line = replace_all(line, ".s64", ".<span class='OpI64'>s64</span>");
line = replace_all(line, ".u8", ".<span class='OpI8'>u8</span>");
line = replace_all(line, ".u16", ".<span class='OpI16'>u16</span>");
line = replace_all(line, ".u32", ".<span class='OpI32'>u32</span>");
line = replace_all(line, ".u64", ".<span class='OpI64'>u64</span>");
line = replace_all(line, ".b8", ".<span class='OpB8'>b8</span>");
line = replace_all(line, ".b16", ".<span class='OpB16'>b16</span>");
line = replace_all(line, ".b32", ".<span class='OpB32'>b32</span>");
line = replace_all(line, ".b64", ".<span class='OpB64'>b64</span>");
line = replace_all(line, ".v2", ".<span class='OpVec2'>v2</span>");
line = replace_all(line, ".v4", ".<span class='OpVec4'>v4</span>");
line = replace_all(line, "ld.", "<span class='Memory'>ld</span>.");
line = replace_all(line, "st.", "<span class='Memory'>st</span>.");
size_t idx;
if ((idx = line.find("//")) != std::string::npos) {
line.insert(idx, "<span class='Comment'>");
line += "</span>";
}
// Predicated instructions
if (line.front() == '@' && indent) {
idx = line.find(' ');
std::string pred = line.substr(1, idx - 1);
line = "<span class='Pred'>@" + variable(pred) + "</span>" + line.substr(idx);
}
// Labels
if (line.front() == 'L' && !indent && (idx = line.find(':')) != std::string::npos) {
std::string label = line.substr(0, idx);
line = "<span class='Label'>" + variable(label) + "</span>:" + line.substr(idx + 1);
}
// Highlight operands
if ((idx = line.find(" \t")) != std::string::npos && line.back() == ';') {
std::string operands_str = line.substr(idx + 2);
operands_str = operands_str.substr(0, operands_str.length() - 1);
std::vector<std::string> operands = split_string(operands_str, ", ");
operands_str = "";
for (size_t opidx = 0; opidx < operands.size(); ++opidx) {
std::string op = operands[opidx];
internal_assert(!op.empty());
if (opidx != 0) {
operands_str += ", ";
}
if (op.back() == '}') {
std::string reg = op.substr(0, op.size() - 1);
operands_str += variable(reg) + '}';
} else if (op.front() == '%') {
operands_str += variable(op);
} else if (op.find_first_not_of("-0123456789") == std::string::npos) {
operands_str += "<span class='IntImm Imm'>";
operands_str += op;
operands_str += "</span>";
} else if (starts_with(op, "0f") &&
op.find_first_not_of("0123456789ABCDEF", 2) == std::string::npos) {
operands_str += "<span class='FloatImm Imm'>";
operands_str += op;
operands_str += "</span>";
} else if (op.front() == '[' && op.back() == ']') {
size_t idx = op.find('+');
if (idx == std::string::npos) {
std::string reg = op.substr(1, op.size() - 2);
operands_str += '[' + variable(reg) + ']';
} else {
std::string reg = op.substr(1, idx - 1);
std::string offset = op.substr(idx + 1);
offset = offset.substr(0, offset.size() - 1);
operands_str += '[' + variable(reg) + "+";
operands_str += "<span class='IntImm Imm'>";
operands_str += offset;
operands_str += "</span>";
operands_str += ']';
}
} else if (op.front() == '{') {
std::string reg = op.substr(1);
operands_str += '{' + variable(reg);
} else if (op.front() == 'L') {
// Labels
operands_str += "<span class='Label'>" + variable(op) + "</span>";
} else {
operands_str += op;
}
}
operands_str += ";";
line = line.substr(0, idx + 2) + operands_str;
}
if (indent) {
stream << " ";
}
stream << line << "\n";
}
print_closing_tag("code");
}
// Prints the args in a function declaration
void print_fndecl_args(const std::vector<LoweredArgument> &args) {
bool print_delim = false;
for (const auto &arg : args) {
if (print_delim) {
print_html_element("span", "matched", ",");
print_text(" ");
}
print_variable(arg.name);
print_delim = true;
}
}
/* Helper functions for printing IR nodes */
void print_constant(std::string cls, Expr c) {
print_opening_tag("span", cls);
stream << c;
print_closing_tag("span");
}
void print_type(Type t) {
print_opening_tag("span", "Type");
stream << t;
print_closing_tag("span");
}
void print_binary_op(const Expr &a, const Expr &b, std::string op) {
print_opening_tag("span", "BinaryOp");
print_html_element("span", "matched", "(");
print(a);
print_text(" ");
print_html_element("span", "matched Operator", op);
print_text(" ");
print(b);
print_html_element("span", "matched", ")");
print_closing_tag("span");
}
void print_function_call(std::string fn_name, const std::vector<Expr> &args, int id) {
print_opening_tag("span", "nav-anchor", "fn-call-" + std::to_string(id));
print_function_call(fn_name, args);
print_closing_tag("span");
}
void print_function_call(std::string fn_name, const std::vector<Expr> &args) {
print_opening_tag("span", "matched");
print_html_element("span", "Symbol matched", fn_name);
print_text("(");
print_closing_tag("span");
bool print_delim = false;
for (const auto &arg : args) {
if (print_delim) {
print_html_element("span", "matched", ", ");
}
print(arg);
print_delim = true;
}
print_html_element("span", "matched", ")");
}
// To avoid generating ridiculously deep DOMs, we flatten blocks here.
void print_block_stmt(const Stmt &stmt) {
if (const Block *b = stmt.as<Block>()) {
print_block_stmt(b->first);
print_block_stmt(b->rest);
} else if (stmt.defined()) {
print(stmt);
}
}
// We also flatten forks
void visit_fork_stmt(const Stmt &stmt) {
if (const Fork *f = stmt.as<Fork>()) {
visit_fork_stmt(f->first);
visit_fork_stmt(f->rest);
} else if (stmt.defined()) {
// Give task a unique id
int id = gen_unique_id();
// Start a dive to hold code for this task
print_opening_tag("div", "ForkTask");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_html_element("span", "keyword matched", "task");
print_toggle_anchor_closing_tag();
// Open code block to hold task body
print_html_element("span", "matched", " {");
// Open indented div to hold body code
print_opening_tag("div", "indent ForkTask", id);
// Print task body
print(stmt);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding task body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this fork task
print_closing_tag("div");
}
}
// Prints compute and data cost buttons/indicators
void print_cost_buttons(const IRNode *op) {
int id = gen_node_id(op);
print_cost_buttons(op, id);
}
void print_cost_buttons(const IRNode *op, int id) {
print_opening_tag("div", "node-cost");
print_compute_cost(op, id);
print_data_movement_cost(op, id);
print_closing_tag("div");
}
// Prints the button/indicator for the compute cost of a line in the program
void print_compute_cost(const IRNode *op, int id) {
int max_line_cost = cost_model.get_max_compute_cost(false);
int line_cost = cost_model.get_compute_cost(op, false);
int block_cost = cost_model.get_compute_cost(op, true);
std::string _id = "cc-" + std::to_string(id);
print_cost_btn(line_cost, block_cost, max_line_cost, _id, "Op Count: ");
}
// Prints the button/indicator for the data movement cost of a line in the program
void print_data_movement_cost(const IRNode *op, int id) {
int max_line_cost = cost_model.get_max_data_movement_cost(false);
int line_cost = cost_model.get_data_movement_cost(op, false);
int block_cost = cost_model.get_data_movement_cost(op, true);
std::string _id = "dc-" + std::to_string(id);
print_cost_btn(line_cost, block_cost, max_line_cost, _id, "Bits Moved: ");
}
// Prints a cost button/indicator
void print_cost_btn(int line_cost, int block_cost, int max_line_cost, std::string id, std::string prefix) {
const int num_cost_buckets = 20;
int line_cost_bin_size = (max_line_cost / num_cost_buckets) + 1;
int block_cost_bin_size = (max_line_cost / num_cost_buckets) + 1;
int line_costc = line_cost / line_cost_bin_size;
int block_costc = block_cost / block_cost_bin_size;
if (line_costc >= num_cost_buckets) {
line_costc = num_cost_buckets - 1;
}
if (block_costc >= num_cost_buckets) {
block_costc = num_cost_buckets - 1;
}
stream << "<div id='" << id << "' class='cost-btn CostColor" << line_costc << "'"
<< " aria-describedby='tooltip-" << id << "'"
<< " line-cost='" << line_cost << "' block-cost='" << block_cost << "'"
<< " line-cost-color='" << line_costc << "' block-cost-color='" << block_costc << "'>"
<< "</div>";
stream << "<span id='tooltip-" << id << "' class='tooltip cond-tooltop' role='tooltip-" << id << "'>"
<< prefix << line_cost
<< "</span>";
}
/* Misc utility methods */
int gen_unique_id() {
return id++;
}
int gen_node_id(const IRNode *node) {
if (node_ids.count(node) == 0) {
node_ids[node] = gen_unique_id();
}
return node_ids[node];
}
/* All visitor functions inherited from IRVisitor */
void visit(const IntImm *op) override {
print_constant("IntImm Imm", Expr(op));
}
void visit(const UIntImm *op) override {
print_constant("UIntImm Imm", Expr(op));
}
void visit(const FloatImm *op) override {
print_constant("FloatImm Imm", Expr(op));
}
void visit(const StringImm *op) override {
print_constant("StringImm Imm", Expr(op));
}
void visit(const Variable *op) override {
print_variable(op->name);
}
void visit(const Cast *op) override {
print_opening_tag("span", "Cast");
print_opening_tag("span", "matched");
print_type(op->type);
print_text("(");
print_closing_tag("span");
print(op->value);
print_html_element("span", "matched", ")");
print_closing_tag("span");
}
void visit(const Reinterpret *op) override {
print_opening_tag("span", "Reinterpret");
print_opening_tag("span", "matched");
print_type(op->type);
print_text("(");
print_closing_tag("span");
print(op->value);
print_html_element("span", "matched", ")");
print_closing_tag("span");
}
void visit(const Add *op) override {
print_binary_op(op->a, op->b, "+");
}
void visit(const Sub *op) override {
print_binary_op(op->a, op->b, "-");
}
void visit(const Mul *op) override {
print_binary_op(op->a, op->b, "*");
}
void visit(const Div *op) override {
print_binary_op(op->a, op->b, "/");
}
void visit(const Mod *op) override {
print_binary_op(op->a, op->b, "%");
}
void visit(const Min *op) override {
print_opening_tag("span", "Min");
print_function_call("min", {op->a, op->b});
print_closing_tag("span");
}
void visit(const Max *op) override {
print_opening_tag("span", "Max");
print_function_call("max", {op->a, op->b});
print_closing_tag("span");
}
void visit(const EQ *op) override {
print_binary_op(op->a, op->b, "==");
}
void visit(const NE *op) override {
print_binary_op(op->a, op->b, "!=");
}
void visit(const LT *op) override {
print_binary_op(op->a, op->b, "<");
}
void visit(const LE *op) override {
print_binary_op(op->a, op->b, "<=");
}
void visit(const GT *op) override {
print_binary_op(op->a, op->b, ">");
}
void visit(const GE *op) override {
print_binary_op(op->a, op->b, ">=");
}
void visit(const And *op) override {
print_binary_op(op->a, op->b, "&&");
}
void visit(const Or *op) override {
print_binary_op(op->a, op->b, "||");
}
void visit(const Not *op) override {
print_opening_tag("span", "Not");
print_text("!");
print(op->a);
print_closing_tag("span");
}
void visit(const Select *op) override {
print_opening_tag("span", "Select");
print_function_call("select", {op->condition, op->true_value, op->false_value});
print_closing_tag("span");
}
void visit(const Load *op) override {
int id = gen_node_id(op);
print_opening_tag("span", "Load nav-anchor", "load-" + std::to_string(id));
print_opening_tag("span", "matched");
print_variable(op->name);
print_text("[");
print_closing_tag("span");
print(op->index);
print_html_element("span", "matched", "]");
if (!is_const_one(op->predicate)) {
print_html_element("span", "keyword", " if ");
print(op->predicate);
}
print_closing_tag("span");
}
void visit(const Ramp *op) override {
print_opening_tag("span", "Ramp");
print_function_call("ramp", {op->base, op->stride, Expr(op->lanes)});
print_closing_tag("span");
}
void visit(const Broadcast *op) override {
print_opening_tag("span", "Broadcast");
print_opening_tag("span", "matched");
print_text("x" + std::to_string(op->lanes) + "(");
print_closing_tag("span");
print(op->value);
print_html_element("span", "matched", ")");
print_closing_tag("span");
}
void visit(const Call *op) override {
int id = gen_node_id(op);
print_opening_tag("span", "Call");
print_function_call(op->name, op->args, id);
print_closing_tag("span");
}
void visit(const Let *op) override {
scope.push(op->name, gen_unique_id());
print_opening_tag("span", "Let");
print_opening_tag("span", "matched");
print_text("(");
print_html_element("span", "keyword", "let ");
print_variable(op->name);
print_html_element("span", "Operator Assign", " = ");
print_closing_tag("span");
print(op->value);
print_html_element("span", "matched keyword", " in ");
print(op->body);
print_html_element("span", "matched", ")");
print_closing_tag("span");
scope.pop(op->name);
}
void visit(const LetStmt *op) override {
int id = gen_node_id(op);
scope.push(op->name, gen_unique_id());
print_opening_tag("div", "LetStmt");
print_cost_buttons(op);
print_opening_tag("p", "WrapLine");
print_opening_tag("span", "cost-highlight", "cost-bg-" + std::to_string(id));
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "let ");
print_variable(op->name);
print_html_element("span", "Operator Assign", " = ");
print_closing_tag("span");
print(op->value);
print_closing_tag("span");
print_closing_tag("p");
print(op->body);
print_closing_tag("div");
scope.pop(op->name);
}
void visit(const AssertStmt *op) override {
print_opening_tag("div", "AssertStmt WrapLine");
print_cost_buttons(op);
print_function_call("assert", {op->condition, op->message});
print_closing_tag("div");
}
void visit(const ProducerConsumer *op) override {
// Give this Producer/Consumer a unique id
int id = gen_node_id(op);
// Push a new scope
scope.push(op->name, id);
// Start a dive to hold code for this Producer/Consumer
print_opening_tag("div", op->is_producer ? "Produce" : "Consumer");
// Print cost buttons
print_cost_buttons(op, id);
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword nav-anchor", op->is_producer ? "produce " : "consume ",
"prodcons-" + std::to_string(id));
print_variable(op->name);
print_closing_tag("span");
print_toggle_anchor_closing_tag();
// Add a button to jump to this producer/consumer in the viz
print_visualization_button("prodcons-viz-" + std::to_string(id));
// Open code block to hold function body
print_html_element("span", "matched", "{");
// Open indented div to hold body code
print_opening_tag("div", "indent ProducerConsumerBody", id);
// Print producer/consumer body
print(op->body);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding producer/consumer body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this producer/consumer
print_closing_tag("div");
// Pop out of loop scope
scope.pop(op->name);
}
void visit(const For *op) override {
// Give this loop a unique id
int id = gen_node_id(op);
// Push scope
scope.push(op->name, id);
// Start a dive to hold code for this allocate
print_opening_tag("div", "For");
// Print cost buttons
print_cost_buttons(op, id);
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_opening_tag("span", "matched");
print_opening_tag("span", "keyword nav-anchor", "loop-" + std::to_string(id));
stream << op->for_type << op->device_api;
print_closing_tag("span");
print_text(" (");
print_closing_tag("span");
print_variable(op->name);
print_html_element("span", "matched", ", ");
print(op->min);
print_html_element("span", "matched", ", ");
print(op->extent);
print_html_element("span", "matched", ")");
print_toggle_anchor_closing_tag();
// Add a button to jump to this loop in the viz
print_visualization_button("loop-viz-" + std::to_string(id));
// Open code block to hold function body
print_html_element("span", "matched", "{");
// Open indented div to hold body code
print_opening_tag("div", "indent ForBody", id);
// Print loop body
print(op->body);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding loop body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this for loop
print_closing_tag("div");
// Pop out of loop scope
scope.pop(op->name);
}
void visit(const Acquire *op) override {
// Give this acquire a unique id
int id = gen_node_id(op);
// Start a dive to hold code for this acquire
print_opening_tag("div", "Acquire");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "acquire", "acquire-" + std::to_string(id));
print_text(" (");
print_closing_tag("span");
print(op->semaphore);
print_html_element("span", "matched", ", ");
print(op->count);
print_html_element("span", "matched", ")");
print_toggle_anchor_closing_tag();
// Add a button to jump to this acquire in the viz
print_visualization_button("acquire-viz-" + std::to_string(id));
// Open code block to hold function body
print_html_element("span", "matched", "{");
// Open indented div to hold body code
print_opening_tag("div", "indent AcquireBody", id);
// Print aquire body
print(op->body);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding acquire body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this acquire
print_closing_tag("div");
}
void visit(const Store *op) override {
int id = gen_node_id(op);
// Start a dive to hold code for this acquire
print_opening_tag("div", "Store WrapLine");
// Print cost buttons
print_cost_buttons(op);
// Print store target
print_opening_tag("span", "matched");
print_opening_tag("span", "nav-anchor", "store-" + std::to_string(id));
print_variable(op->name);
print_text("[");
print_closing_tag("span");
print_closing_tag("span");
print(op->index);
print_html_element("span", "matched", "]");
print_html_element("span", "Operator Assign Matched", " = ");
// Print store value
print_opening_tag("span", "StoreValue");
print(op->value);
if (!is_const_one(op->predicate)) {
print_html_element("span", "keyword", " if ");
print(op->predicate);
}
print_closing_tag("span");
// Close div holding this store
print_closing_tag("div");
}
void visit(const Provide *op) override {
print_opening_tag("div", "Provide WrapLine");
print_function_call(op->name, op->args);
if (op->values.size() > 1) {
print_html_element("span", "matched", " = {");
for (size_t i = 0; i < op->args.size(); i++) {
if (i > 0) {
print_html_element("span", "matched", ", ");
}
print(op->args[i]);
}
print_html_element("span", "matched", "}");
} else {
print(op->values[0]);
}
print_closing_tag("div");
}
void visit(const Allocate *op) override {
int id = gen_node_id(op);
// Push scope
scope.push(op->name, gen_unique_id());
// Start a dive to hold code for this allocate
print_opening_tag("div", "Allocate");
// Print cost buttons
print_cost_buttons(op);
// Print allocation name, type and extents
print_opening_tag("span", "matched");
print_html_element("span", "keyword nav-anchor", "allocate ", "allocate-" + std::to_string(id));
print_variable(op->name);
print_text("[");
print_closing_tag("span");
print_type(op->type);
for (const auto &extent : op->extents) {
print_text(" * ");
print(extent);
}
print_html_element("span", "matched", "]");
// Print memory type
if (op->memory_type != MemoryType::Auto) {
print_html_element("span", "keyword", " in ");
stream << op->memory_type;
}
// Print allocation condition
if (!is_const_one(op->condition)) {
print_html_element("span", "keyword", " if ");
print(op->condition);
}
// Print custom new and free expressions
if (op->new_expr.defined()) {
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "custom_new");
print_text(" {");
print_closing_tag("span");
print(op->new_expr);
print_html_element("span", "matched", "}");
}
if (!op->free_function.empty()) {
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "custom_free");
print_text(" {");
print_closing_tag("span");
print_text(" " + op->free_function + "(); ");
print_html_element("span", "matched", "}");
}
// Add a button to jump to this allocation in the viz
print_visualization_button("allocate-viz-" + std::to_string(id));
// Print allocation body
print_opening_tag("div", "AllocateBody");
print(op->body);
print_closing_tag("div");
// Close dive holding the allocate
print_closing_tag("div");
// Pop out of allocate scope
scope.pop(op->name);
}
void visit(const Free *op) override {
print_opening_tag("div", "Free WrapLine");
print_cost_buttons(op);
print_html_element("span", "keyword", "free ");
print_variable(op->name);
print_closing_tag("div");
}
void visit(const Realize *op) override {
// Give this acquire a unique id
int id = gen_node_id(op);
// Push scope
scope.push(op->name, id);
// Start a dive to hold code for this realize
print_opening_tag("div", "Realize");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "realize", "realize-" + std::to_string(id));
print_variable(op->name);
print_text(" (");
for (size_t i = 0; i < op->bounds.size(); i++) {
print_html_element("span", "matched", "[");
print(op->bounds[i].min);
print_html_element("span", "matched", ", ");
print(op->bounds[i].extent);
print_html_element("span", "matched", "]");
if (i < op->bounds.size() - 1) {
print_html_element("span", "matched", ", ");
}
}
print_html_element("span", "matched", ")");
// Print predicate
if (!is_const_one(op->condition)) {
print_html_element("span", "keyword", " if ");
print(op->condition);
}
print_toggle_anchor_closing_tag();
// Add a button to jump to this realize in the viz
print_visualization_button("realize-viz-" + std::to_string(id));
// Open code block to hold function body
print_html_element("span", "matched", " {");
// Open indented div to hold body code
print_opening_tag("div", "indent RealizeBody", id);
// Print realization body
print(op->body);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding realize body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this realize
print_closing_tag("div");
// Pop out of this scope
scope.pop(op->name);
}
void visit(const Block *op) override {
print_opening_tag("div", "Block");
print_block_stmt(op->first);
print_block_stmt(op->rest);
print_closing_tag("div");
}
void visit(const Fork *op) override {
// Give this acquire a unique id
int id = gen_node_id(op);
// Start a dive to hold code for this realize
print_opening_tag("div", "Fork");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_html_element("span", "keyword matched", "fork");
print_toggle_anchor_closing_tag();
// Open code block to hold fork body
print_html_element("span", "matched", " {");
// Open indented div to hold body code
print_opening_tag("div", "indent ForkBody", id);
// Print fork body
visit_fork_stmt(op->first);
visit_fork_stmt(op->rest);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding fork body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this fork
print_closing_tag("div");
}
void visit(const IfThenElse *op) override {
// Give this acquire a unique id
int then_block_id = gen_unique_id();
int then_node_id = gen_node_id(op->then_case.get());
// Start a dive to hold code for this conditional
print_opening_tag("div", "IfThenElse");
// Print cost buttons
print_cost_buttons(op, then_block_id);
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(then_block_id);
// -- print icon
print_show_hide_icon(then_block_id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword nav-anchor IfSpan", "if", "cond-" + std::to_string(then_node_id));
print_text(" (");
print_closing_tag("span");
print(op->condition);
print_html_element("span", "matched", ")");
print_toggle_anchor_closing_tag();
// Add a button to jump to this conditional in the viz
print_visualization_button("cond-viz-" + std::to_string(then_node_id));
// Flatten nested if's in the else case as an
// `if-then-else_if-else` sequence
while (true) {
/* Handle the `then` case */
// Open code block to hold `then` case
print_html_element("span", "matched", " {");
// Open indented div to hold code for the `then` case
print_opening_tag("div", "indent ThenBody", then_block_id);
// Print then case body
print(op->then_case);
// Close indented div holding `then` case
print_closing_tag("div");
// Close code block holding `then` case
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(then_block_id), "}");
// If there is no `else` case, we are done!
if (!op->else_case.defined()) {
break;
}
/* Handle the `else` case */
// If the else-case is another if-then-else, flatten it
if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
// Generate a new id for the `else-if` case
then_block_id = gen_unique_id();
then_node_id = gen_node_id(nested_if->then_case.get());
// Print cost buttons
print_cost_buttons(op, then_block_id);
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(then_block_id);
// -- print icon
print_show_hide_icon(then_block_id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword nav-anchor IfSpan", "else if", "cond-" + std::to_string(then_node_id));
print_text(" (");
print_closing_tag("span");
print(nested_if->condition);
print_html_element("span", "matched", ")");
print_toggle_anchor_closing_tag();
// Add a button to jump to this conditional branch in the viz
print_visualization_button("cond-viz-" + std::to_string(then_node_id));
// Update op to the nested if for next loop iteration
op = nested_if;
} else { // Otherwise, print it and we are done!
int else_block_id = gen_unique_id();
int else_node_id = gen_node_id(op->else_case.get());
// Print cost buttons
print_cost_buttons(op, else_block_id);
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(else_block_id);
// -- print icon
print_show_hide_icon(else_block_id);
// -- print text
print_opening_tag("span", "matched");
print_html_element("span", "keyword nav-anchor IfSpan", "else", "cond-" + std::to_string(else_node_id));
print_closing_tag("span");
print_toggle_anchor_closing_tag();
// Add a button to jump to this conditional branch in the viz
print_visualization_button("cond-viz-" + std::to_string(else_node_id));
// Open code block to hold `else` case
print_html_element("span", "matched", " {");
// Open indented div to hold code for the `then` case
print_opening_tag("div", "indent ElseBody", else_block_id);
// Print `else` case body
print(op->else_case);
// Close indented div holding `else` case
print_closing_tag("div");
// Close code block holding `else` case
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(else_block_id), "}");
break;
}
}
// Close div holding the conditional
print_closing_tag("div");
}
void visit(const Evaluate *op) override {
print_opening_tag("div", "Block");
// Print cost buttons
print_cost_buttons(op);
print(op->value);
print_closing_tag("div");
}
void visit(const Shuffle *op) override {
print_opening_tag("div", "Block");
if (op->is_concat()) {
print_function_call("concat_vectors", op->vectors);
} else if (op->is_interleave()) {
print_function_call("interleave_vectors", op->vectors);
} else if (op->is_extract_element()) {
std::vector<Expr> args = op->vectors;
args.emplace_back(op->slice_begin());
print_function_call("extract_element", args);
} else if (op->is_slice()) {
std::vector<Expr> args = op->vectors;
args.emplace_back(op->slice_begin());
args.emplace_back(op->slice_stride());
args.emplace_back(static_cast<int>(op->indices.size()));
print_function_call("slice_vectors", args);
} else {
std::vector<Expr> args = op->vectors;
for (int i : op->indices) {
args.emplace_back(i);
}
print_function_call("shuffle", args);
}
print_closing_tag("div");
}
void visit(const VectorReduce *op) override {
print_opening_tag("div", "VectorReduce");
print_type(op->type);
print_function_call("vector_reduce", {op->op, op->value});
print_closing_tag("div");
}
void visit(const Prefetch *op) override {
print_opening_tag("div", "Prefetch");
// Print cost buttons
print_cost_buttons(op);
// Print prefetch
print_html_element("span", "matched keyword", "prefetch ");
print_variable(op->name);
print_html_element("span", "matched", "(");
for (size_t i = 0; i < op->bounds.size(); i++) {
print_html_element("span", "matched", "[");
print(op->bounds[i].min);
print_html_element("span", "matched", ",");
print(op->bounds[i].extent);
print_html_element("span", "matched", "]");
if (i < op->bounds.size() - 1) {
print_html_element("span", "matched", ", ");
}
}
print_html_element("span", "matched", ")");
// Print condition
if (!is_const_one(op->condition)) {
print_html_element("span", "keyword", " if ");
print(op->condition);
}
// Print prefetch body
print_opening_tag("div", "indent PrefetchBody");
print(op->body);
print_closing_tag("div");
print_closing_tag("div");
}
void visit(const Atomic *op) override {
// Give this node a unique id
int id = gen_unique_id();
// Start a dive to hold code for this atomic
print_opening_tag("div", "Atomic");
// Generate the show hide icon/text buttons
print_toggle_anchor_opening_tag(id);
// -- print icon
print_show_hide_icon(id);
// -- print text
print_html_element("span", "matched keyword", "atomic");
if (!op->mutex_name.empty()) {
print_html_element("span", "matched", "(");
print_html_element("span", "Symbol", op->mutex_name);
print_html_element("span", "matched", ")");
}
print_toggle_anchor_closing_tag();
// Open code block to hold atomic body
print_html_element("span", "matched", " {");
// Open indented div to hold atomic code
print_opening_tag("div", "indent AtomicBody", id);
// Print atomic body
print(op->body);
// Close indented div holding body code
print_closing_tag("div");
// Close code block holding fork body
print_html_element("span", "matched ClosingBrace cb-" + std::to_string(id), "}");
// Close div holding this atomic
print_closing_tag("div");
}
};
/** HTMLVisualizationPrinter
* Visualizes the IR in HTML. The visualization is essentially
* an abstracted version of the code, highlighting the higher
* level execution pipeline along with key properties of the
* computation performed at each stage.
*/
class HTMLVisualizationPrinter : public IRVisitor {
public:
HTMLVisualizationPrinter(std::ofstream &os, std::map<const IRNode *, int> &nids)
: stream(os), node_ids(nids) {
}
// Make class non-copyable and non-moveable
HTMLVisualizationPrinter(const HTMLVisualizationPrinter &) = delete;
HTMLVisualizationPrinter &operator=(const HTMLVisualizationPrinter &) = delete;
void init_cost_info(IRCostModel cm) {
cost_model = std::move(cm);
}
void print(const Module &m, AssemblyInfo asm_info) {
assembly_info = std::move(asm_info);
for (const auto &fn : m.functions()) {
print(fn);
}
}
private:
// Handle to output file stream
std::ofstream &stream;
// Used to track the context within generated HTML
std::vector<std::string> context_stack_tags;
// Assembly line number info
AssemblyInfo assembly_info;
// Holds cost information for visualized program
IRCostModel cost_model;
// Generate unique ids
int id = 0;
std::map<const IRNode *, int> &node_ids;
int gen_unique_id() {
return id++;
}
int gen_node_id(const IRNode *node) {
if (node_ids.count(node) == 0) {
node_ids[node] = gen_unique_id();
}
return node_ids[node];
}
/* Private print functions to handle various IR types */
void print(const LoweredFunc &fn) {
int id = gen_unique_id();
// Start a div to hold the function viz
print_opening_tag("div", "center fn-wrapper");
// Create the header bar
print_opening_tag("div", "fn-header");
print_collapse_expand_btn(id);
print_code_button("lowered-func-" + fn.name);
print_html_element("span", "fn-title", "Func: " + fn.name, "lowered-func-viz-" + fn.name);
print_closing_tag("div");
// Print function body
print_opening_tag("div", "fn-body", "viz-" + std::to_string(id));
fn.body.accept(this);
print_closing_tag("div");
// Close function div
print_closing_tag("div");
}
/* Methods used to emit common HTML patterns */
// Prints the opening tag for the specified html element.
void print_opening_tag(const std::string &tag, const std::string &cls) {
stream << "<" << tag << " class='" << cls << "'>";
context_stack_tags.push_back(tag);
}
void print_opening_tag(const std::string &tag, const std::string &cls, const std::string &id) {
stream << "<" << tag << " class='" << cls << "' id='" << id << "'>";
context_stack_tags.push_back(tag);
}
// Prints the closing tag for the specified html element.
void print_closing_tag(const std::string &tag) {
internal_assert(tag == context_stack_tags.back());
context_stack_tags.pop_back();
stream << "</" + tag + ">";
}
// Prints an html element: opening tag, body and closing tag
void print_html_element(const std::string &tag, const std::string &cls, const std::string &body) {
print_opening_tag(tag, cls);
stream << body;
print_closing_tag(tag);
}
void print_html_element(const std::string &tag, const std::string &cls, const std::string &body, const std::string &id) {
print_opening_tag(tag, cls, id);
stream << body;
print_closing_tag(tag);
}
// Prints text to stream
void print_text(const std::string &x) {
stream << x;
}
// Prints a button to sync visualization with code
void print_code_button(const std::string &id) {
stream << "<button class='icon-btn sync-btn' onclick='scrollToCode(\"" << id << "\")'>"
<< " <i class='bi bi-arrow-left-square' title='Jump to code'></i>"
<< "</button>";
}
// Prints a button to sync visualization with assembly
void print_asm_button(const std::string &id) {
stream << "<button class='icon-btn sync-btn' onclick='scrollToAsm(\"" << id << "\")'>"
<< " <i class='bi bi-arrow-right-square' title='Jump to assembly'></i>"
<< "</button>";
}
// Prints a function-call box
void print_fn_button(const std::string &name, int id) {
print_opening_tag("div", "fn-call");
print_code_button("fn-call-" + std::to_string(id));
print_text(get_as_var(name) + "(...)");
print_closing_tag("div");
}
// Prints a button to collapse or expand a visualization box
void print_collapse_expand_btn(int id) {
stream << "<button class='icon-btn' id='viz-" << id << "-hide' onclick='return toggleViz(\"viz-" << id << "\");'>"
<< " <i class='bi bi-dash-square' title='Collapse block'></i>"
<< "</button>"
<< "<button class='icon-btn' id='viz-" << id << "-show' style = 'display:none;' onclick='return toggleViz(\"viz-" << id << "\");'>"
<< " <i class='bi bi-plus-square' title='Expand block'></i>"
<< "</button>";
}
// Prints the box title within the div.box-header
void print_box_title(const std::string &title, const std::string &anchor) {
print_opening_tag("div", "box-title");
print_html_element("span", "", title, anchor);
print_closing_tag("div");
}
// Prints the cost indicator buttons within div.box-header
void print_cost_buttons(int id, const IRNode *op) {
print_opening_tag("div", "viz-cost-btns");
// Print compute cost indicator
int max_line_ccost = cost_model.get_max_compute_cost(false);
int line_ccost = cost_model.get_compute_cost(op, false);
int block_ccost = cost_model.get_compute_cost(op, true);
print_cost_button(line_ccost, block_ccost, max_line_ccost, "vcc-" + std::to_string(id), "Op Count: ");
// Print data movement cost indicator
int max_line_dcost = cost_model.get_max_data_movement_cost(false);
int line_dcost = cost_model.get_data_movement_cost(op, false);
int block_dcost = cost_model.get_data_movement_cost(op, true);
// Special handling for Store nodes; since unlike the code view
// the viz view prints stores and loads seperately, therefore using
// inclusive cost is confusing.
if (op->node_type == IRNodeType::Store) {
const Store *st = static_cast<const Store *>(op);
line_dcost = st->value.type().bits() * st->value.type().lanes();
block_dcost = line_dcost;
}
print_cost_button(line_dcost, block_dcost, max_line_dcost, "vdc-" + std::to_string(id), "Bits Moved: ");
print_closing_tag("div");
}
void print_cost_button(int line_cost, int block_cost, int max_line_cost, const std::string &id, const std::string &prefix) {
const int num_cost_buckets = 20;
int line_cost_bin_size = (max_line_cost / num_cost_buckets) + 1;
int block_cost_bin_size = (max_line_cost / num_cost_buckets) + 1;
int line_costc = line_cost / line_cost_bin_size;
int block_costc = block_cost / block_cost_bin_size;
if (line_costc >= num_cost_buckets) {
line_costc = num_cost_buckets - 1;
}
if (block_costc >= num_cost_buckets) {
block_costc = num_cost_buckets - 1;
}
stream << "<div id='" << id << "' class='cost-btn CostColor" << line_costc << "'"
<< " aria-describedby='tooltip-" << id << "'"
<< " line-cost='" << line_cost << "' block-cost='" << block_cost << "'"
<< " line-cost-color='" << line_costc << "' block-cost-color='" << block_costc << "'>"
<< "</div>";
stream << "<span id='tooltip-" << id << "' class='tooltip cond-tooltop' role='tooltip-" << id << "'>"
<< prefix << line_cost
<< "</span>";
}
// Prints the box .box-header within div.box
void print_box_header(int id, const IRNode *op, const std::string &anchor, const std::string &code_anchor, const std::string &title) {
print_opening_tag("div", "box-header");
print_collapse_expand_btn(id);
print_code_button(code_anchor);
print_box_title(title, anchor);
print_cost_buttons(id, op);
print_closing_tag("div");
}
// Prints the box .box-header within div.box, contains the asm info button
void print_box_header_asm(int id, const IRNode *op, const std::string &anchor, const std::string &code_anchor, const std::string &asm_anchor, const std::string &title) {
print_opening_tag("div", "box-header");
print_collapse_expand_btn(id);
print_code_button(code_anchor);
print_asm_button(asm_anchor);
print_box_title(title, anchor);
print_cost_buttons(id, op);
print_closing_tag("div");
}
// Converts an expr to a string without printing to stream
std::string get_as_str(const Expr &e) {
return get_as_str(e, "");
}
std::string get_as_str(const Expr &e, const std::string &prefix) {
if (prefix == "Else") {
return "Else";
}
std::ostringstream ss;
HTMLCodePrinter<std::ostringstream> printer(ss, node_ids);
e.accept(&printer);
std::string html_e = ss.str();
if (large_expr(e)) {
return prefix + truncate_html(html_e);
} else {
return prefix + html_e;
}
}
// Return variable name wrapped with html that enables matching
std::string get_as_var(const std::string &name) {
return "<b class='variable matched'>" + name + "</b>";
}
// Sometimes the expressions are too large to show within the viz. In
// such cases we use tooltips.
bool large_expr(const Expr &e) {
std::ostringstream ss;
ss << e;
return ss.str().size() > 50;
}
std::string truncate_html(const std::string &cond) {
int id = gen_unique_id();
std::ostringstream ss;
// Show condition expression button
ss << "<button title='Click to see path condition' id='cond-" << id << "' aria-describedby='cond-tooltip-" << id << "' class='trunc-cond' role='button'>"
<< "..."
<< "</button>";
// Tooltip that shows condition expression
ss << "<span id='cond-tooltip-" << id << "' class='tooltip cond-tooltop' role='cond-tooltip-" << id << "'>"
<< cond
<< "</span>";
return ss.str();
}
// Prints a single node in an `if-elseif-...-else` chain
void print_if_tree_node(const Stmt &node, const Expr &cond, const std::string &prefix) {
// Assign unique id to this node
int box_id = gen_unique_id();
int node_id = gen_node_id(node.get());
// Start tree node
print_opening_tag("li", "");
print_opening_tag("span", "tf-nc if-node");
// Start a box to hold viz
print_opening_tag("div", "box center IfBox");
// Create viz content
std::string aid = std::to_string(node_id);
print_box_header(box_id, node.get(), "cond-viz-" + aid, "cond-" + aid, get_as_str(cond, prefix));
// Print contents of node
print_opening_tag("div", "box-body", "viz-" + std::to_string(box_id));
node.accept(this);
print_closing_tag("div");
// Close box holding viz
print_closing_tag("div");
// Close tree node
print_closing_tag("span");
print_closing_tag("li");
}
/* Visitor functions for each IR node */
using IRVisitor::visit;
/* Override key visit functions */
void visit(const Allocate *op) override {
// Assign unique id to this node
int id = gen_node_id(op);
// Start a box to hold viz
print_opening_tag("div", "box center AllocateBox");
// Print box header
std::string aid = std::to_string(id);
print_box_header(id, op, "allocate-viz-" + aid, "allocate-" + aid, "Allocate: " + op->name);
// Start a box to hold viz
print_opening_tag("div", "box-body", "viz-" + std::to_string(id));
// Generate a table with allocation details
print_opening_tag("table", "allocate-table");
// - Memory type
stream << "<tr><th scope='col'>Memory Type</th><td>" << op->memory_type << "</td></tr>";
// - Allocation condition
if (!is_const_one(op->condition)) {
stream << "<tr><th scope='col'>Condition</th><td>" << op->condition << "</td></tr>";
}
// - Data type
stream << "<tr><th scope='col'>Data Type</th><td>" << op->type << "</td></tr>";
// - Dimensions
for (size_t i = 0; i < op->extents.size(); i++) {
stream << "<tr><th scope='col'>Dim-" << i << "</th><td>" << get_as_str(op->extents[i]) << "</td></tr>";
}
print_closing_tag("table");
op->body.accept(this);
print_closing_tag("div");
print_closing_tag("div");
}
void visit(const For *op) override {
// Assign unique id to this node
int id = gen_node_id(op);
// Start a box to hold viz
print_opening_tag("div", "box center ForBox");
// Print box header
std::string aid = std::to_string(id);
int asm_lno = assembly_info.get_asm_lno((uint64_t)op);
if (asm_lno == -1) {
print_box_header(id, op, "loop-viz-" + aid, "loop-" + aid, "For: " + get_as_var(op->name));
} else {
print_box_header_asm(id, op, "loop-viz-" + aid, "loop-" + aid, std::to_string(asm_lno), "For: " + get_as_var(op->name));
}
// Start a box to hold viz
print_opening_tag("div", "box-body", "viz-" + std::to_string(id));
// Generate a table with loop details
print_opening_tag("table", "allocate-table");
// - Loop type
if (op->for_type != ForType::Serial) {
stream << "<tr><th scope='col'>Loop Type</th><td>" << op->for_type << "</td></tr>";
}
// - Device API
if (op->device_api != DeviceAPI::None) {
stream << "<tr><th scope='col'>Device API</th><td>" << op->device_api << "</td></tr>";
}
// - Min
stream << "<tr><th scope='col'>Min</th><td>" << get_as_str(op->min) << "</td></tr>";
// - Extent
stream << "<tr><th scope='col'>Extent</th><td>" << get_as_str(op->extent) << "</td></tr>";
print_closing_tag("table");
op->body.accept(this);
print_closing_tag("div");
print_closing_tag("div");
}
void visit(const IfThenElse *op) override {
// Open If tree
print_opening_tag("div", "tf-tree tf-gap-sm tf-custom-ir-viz");
// Create root 'cond' node
if (op->else_case.defined()) {
print_opening_tag("ul", "");
print_opening_tag("li", "");
print_html_element("span", "tf-nc if-node if-root-node", "Control Flow Branching");
}
// Create children nodes ('then', 'else if' and 'else' cases)
print_opening_tag("ul", "");
// `then` case
print_if_tree_node(op->then_case, op->condition, "If: ");
// `else if` cases
const IfThenElse *nested_if = op;
while (nested_if->else_case.defined() && (nested_if->else_case.as<IfThenElse>())) {
nested_if = nested_if->else_case.as<IfThenElse>();
print_if_tree_node(nested_if->then_case, nested_if->condition, "Else If: ");
}
// `else` case
if (nested_if->else_case.defined()) {
print_if_tree_node(nested_if->else_case, UIntImm::make(UInt(1), 1), "Else");
}
print_closing_tag("ul");
// Close If tree
if (op->else_case.defined()) {
print_closing_tag("li");
print_closing_tag("ul");
}
print_closing_tag("div");
}
void visit(const ProducerConsumer *op) override {
// Assign unique id to this node
int id = gen_node_id(op);
// Start a box to hold viz
std::string box_name = op->is_producer ? "ProducerBox" : "ConsumerBox";
print_opening_tag("div", "box center " + box_name);
// Print box header
std::string aid = std::to_string(id);
std::string prefix = op->is_producer ? "Produce: " : "Consume: ";
int asm_lno = assembly_info.get_asm_lno((uint64_t)op);
if (asm_lno == -1) {
print_box_header(id, op, "prodcons-viz-" + aid, "prodcons-" + aid, prefix + get_as_var(op->name));
} else {
print_box_header_asm(id, op, "prodcons-viz-" + aid, "prodcons-" + aid, std::to_string(asm_lno), prefix + get_as_var(op->name));
}
// Print the body
print_opening_tag("div", "box-body", "viz-" + std::to_string(id));
op->body.accept(this);
print_closing_tag("div");
// Close div holding the producer/consumer
print_closing_tag("div");
}
void visit(const Store *op) override {
// Visit the value first. We want to show any loads
// that happen before the store operation
op->value.accept(this);
// Assign unique id to this node
int id = gen_node_id(op);
// Start a box to hold viz
print_opening_tag("div", "box center StoreBox");
// Print box header
std::string aid = std::to_string(id);
print_box_header(id, op, "store-viz-" + aid, "store-" + aid, "Store: " + get_as_var(op->name));
// Start a box to hold viz
print_opening_tag("div", "box-body", "viz-" + std::to_string(id));
// Generate a table with store details
print_opening_tag("table", "allocate-table");
// - Store predicate
if (!is_const_one(op->predicate)) {
stream << "<tr><th scope='col'>Predicate</th><td>" << get_as_str(op->predicate) << "</td></tr>";
}
// - Alignment
const bool show_alignment = op->value.type().is_vector() && op->alignment.modulus > 1;
if (show_alignment) {
stream << "<tr><th scope='col'>Alignment</th><td>"
<< "aligned(" << op->alignment.modulus << ", " << op->alignment.remainder << ")"
<< "</td></tr>";
}
// - Qualifiers
if (op->value.type().is_vector()) {
const Ramp *idx = op->index.as<Ramp>();
if (idx && is_const_one(idx->stride)) {
stream << "<tr><th scope='col'>Type</th><td>Dense Vector</td></tr>";
} else {
stream << "<tr><th scope='col'>Type</th><td>Strided Vector</td></tr>";
}
stream << "<tr><th scope='col'>Output Tile</th><td>" << op->value.type() << "</td></tr>";
} else {
stream << "<tr><th scope='col'>Type</th><td>Scalar</td></tr>";
stream << "<tr><th scope='col'>Output</th><td>" << op->value.type() << "</td></tr>";
}
print_closing_tag("table");
print_closing_tag("div");
print_closing_tag("div");
}
void visit(const Load *op) override {
// Assign unique id to this node
int id = gen_node_id(op);
// Start a box to hold viz
print_opening_tag("div", "box center LoadBox");
// Print box header
std::string aid = std::to_string(id);
print_box_header(id, op, "load-viz-" + aid, "load-" + aid, "Load: " + get_as_var(op->name));
// Start a box to hold viz
print_opening_tag("div", "box-body", "viz-" + std::to_string(id));
// Generate a table with load details
print_opening_tag("table", "allocate-table");
// - Load predicate
if (!is_const_one(op->predicate)) {
stream << "<tr><th scope='col'>Predicate</th><td>" << get_as_str(op->predicate) << "</td></tr>";
}
// - Alignment
const bool show_alignment = op->type.is_vector() && op->alignment.modulus > 1;
if (show_alignment) {
stream << "<tr><th scope='col'>Alignment</th><td>"
<< "aligned(" << op->alignment.modulus << ", " << op->alignment.remainder << ")"
<< "</td></tr>";
}
// - Qualifiers
if (op->type.is_vector()) {
const Ramp *idx = op->index.as<Ramp>();
if (idx && is_const_one(idx->stride)) {
stream << "<tr><th scope='col'>Type</th><td>Dense Vector</td></tr>";
} else {
stream << "<tr><th scope='col'>Type</th><td>Strided Vector</td></tr>";
}
stream << "<tr><th scope='col'>Output Tile</th><td>" << op->type << "</td></tr>";
} else {
stream << "<tr><th scope='col'>Type</th><td>Scalar</td></tr>";
stream << "<tr><th scope='col'>Output</th><td>" << op->type << "</td></tr>";
}
print_closing_tag("table");
print_closing_tag("div");
print_closing_tag("div");
}
void visit(const Call *op) override {
int id = gen_node_id(op);
// Add viz support for key functions/intrinsics
if (op->name == "halide_do_par_for") {
print_fn_button(op->name, id);
} else if (op->name == "halide_do_par_task") {
print_fn_button(op->name, id);
} else if (op->name == "_halide_buffer_init") {
print_fn_button(op->name, id);
} else if (op->name.rfind("_halide", 0) != 0) {
// Assumption: We want to ignore intrinsics starting with _halide
// but for everything else, generate a warning
debug(2) << "Function call ignored by IRVisualizer: " << op->name << "\n";
}
}
};
/** IRVisualizer Class
* Generates the output html page. Currently the html page has
* three key tabs: IR code, Visualized pipeline and the generated
* assembly.
*/
class IRVisualizer {
public:
// Construct the visualizer and point it to the output file
explicit IRVisualizer(const std::string &html_output_filename,
const Module &m,
const std::string &assembly_input_filename)
: html_code_printer(stream, node_ids), html_viz_printer(stream, node_ids) {
// Open output file
stream.open(html_output_filename.c_str());
// Load assembly code -- if not explicit specified, assume it will have matching pathname
// as our output file, with a different extension.
if (assembly_input_filename.empty()) {
// get_output_info() is the One True Source Of Truth for expected file suffixes,
// so use that rather than hardcoding it here.
const auto info = get_output_info(m.target());
const std::string stmt_html_extension = info.at(OutputFileType::stmt_html).extension;
const size_t pos = html_output_filename.rfind(stmt_html_extension);
user_assert(pos != std::string::npos)
<< "Unable to find expected extension (" << html_output_filename
<< ") in filename (" << html_output_filename << ")\n";
const std::string asm_extension = info.at(OutputFileType::assembly).extension;
const std::string asm_file = html_output_filename.substr(0, pos) + asm_extension;
load_asm_code(asm_file);
} else {
load_asm_code(assembly_input_filename);
}
}
// Generate the html visualization of the input module
void generate_html(const Module &m) {
// Before we generate any html, we annotate IR nodes with
// line numbers containing corresponding assembly code. This
// code is based on darya-ver's original implementation. We
// use comments in the generated assembly to infer association
// between Halide IR and assembly -- unclear how reliable this is.
asm_info.generate(asm_stream.str(), m);
// Run the cost model over this module to pre-compute all
// node costs
cost_model.compute_all_costs(m);
html_code_printer.init_cost_info(cost_model);
html_viz_printer.init_cost_info(cost_model);
// Generate html page
stream << "<html>\n";
generate_head(m);
generate_body(m);
stream << "</html>";
}
private:
// Handle to output file stream
std::ofstream stream;
// Handle to assembly file stream
std::ifstream assembly;
// Holds cost information for visualized program
IRCostModel cost_model;
// Annotate AST nodes with unique IDs
std::map<const IRNode *, int> node_ids;
// Used to translate IR to code in HTML
HTMLCodePrinter<std::ofstream> html_code_printer;
// Used to translate IR to visualization in HTML
HTMLVisualizationPrinter html_viz_printer;
/* Methods for generating the <head> section of the html file */
void generate_head(const Module &m) {
stream << "<head>\n";
stream << "<title>Visualizing Module: " << m.name() << "</title>\n";
generate_dependency_links();
generate_stylesheet();
stream << "</head>\n";
}
// Loads the html code responsible for linking with various js/css libraries from
// `ir_visualizer/dependencies.html`
void generate_dependency_links() {
stream << halide_html_template_StmtToViz_dependencies;
}
// Loads the stylesheet code from `ir_visualizer/stylesheet.html`
void generate_stylesheet() {
stream << halide_html_template_StmtToViz_stylesheet;
}
/* Methods for generating the <body> section of the html file */
void generate_body(const Module &m) {
stream << "<body>\n";
stream << " <div id='page-container'>\n";
generate_visualization_tabs(m);
stream << " </div>\n";
stream << "</body>";
generate_javascript();
}
// Generate the three visualization tabs
void generate_visualization_tabs(const Module &m) {
stream << "<div id='visualization-tabs'>\n";
generate_ir_tab(m);
generate_resize_bar_1();
generate_visualization_tab(m);
generate_resize_bar_2();
generate_assembly_tab(m);
stream << "</div>\n";
}
// Generate tab 1/3: Lowered IR code with syntax highlighting in HTML
void generate_ir_tab(const Module &m) {
stream << "<div id='ir-code-tab'>\n";
html_code_printer.print(m);
stream << "</div>\n";
}
// Generate tab 2/3: Lowered IR code with syntax highlighting in HTML
void generate_visualization_tab(const Module &m) {
stream << "<div id='ir-visualization-tab'>\n";
html_viz_printer.print(m, asm_info);
stream << "</div>\n";
}
// Generate tab 3/3: Generated assembly code
void generate_assembly_tab(const Module &m) {
stream << "<div id='assembly-tab'>\n";
stream << "<div id='assemblyContent' style='display: none;'>\n";
stream << "<pre>\n";
stream << asm_stream.str();
stream << "</pre>\n";
stream << "</div>\n";
stream << "</div>\n";
}
// Generate a resizing bar to control the width of code and visualization tabs
void generate_resize_bar_1() {
stream << R"(<div class='resize-bar' id='resize-bar-1'>
<div class='collapse-btns'>
<div>
<button class='icon-btn resize-btn' onclick='collapse_code_tab()'>
<i class='bi bi-arrow-bar-left' title='Collapse code tab'></i>
</button>
</div>
<div>
<button class='icon-btn resize-btn' onclick='collapseR_visualization_tab()'>
<i class='bi bi-arrow-bar-right' title='Collapse visualization tab'></i>
</button>
</div>
</div>
</div>)";
}
// Generate a resizing bar to control the width of visualization and assembly tabs
void generate_resize_bar_2() {
stream << R"(<div class='resize-bar' id='resize-bar-2'>
<div class='collapse-btns'>
<div>
<button class='icon-btn resize-btn' onclick='collapseL_visualization_tab()'>
<i class='bi bi-arrow-bar-left' title='Collapse visualization tab'></i>
</button>
</div>
<div>
<button class='icon-btn resize-btn' onclick='collapse_assembly_tab()'>
<i class='bi bi-arrow-bar-right' title='Collapse assembly tab'></i>
</button>
</div>
</div>
</div>)";
}
// Loads and initializes the javascript template from `ir_visualizer / javascript_template.html`
void generate_javascript() {
stream << halide_html_template_StmtToViz_javascript;
}
/* Misc helper methods */
// Load assembly code from file
std::ostringstream asm_stream;
AssemblyInfo asm_info;
void load_asm_code(const std::string &asm_file) {
user_assert(file_exists(asm_file)) << "Unable to open assembly file: " << asm_file << "\n";
// Open assembly file
assembly.open(asm_file.c_str());
// Slurp the code into asm_stream
std::string line;
while (getline(assembly, line)) {
asm_stream << line << "\n";
}
}
};
// The external interface to this module
void print_to_viz(const std::string &html_output_filename,
const Module &m,
const std::string &assembly_input_filename) {
IRVisualizer visualizer(html_output_filename, m, assembly_input_filename);
visualizer.generate_html(m);
debug(1) << "Done generating HTML IR Visualization - printed to: " << html_output_filename << "\n";
}
} // namespace Internal
} // namespace Halide
Computing file changes ...