#include #include #include #include "Associativity.h" #include "BoundaryConditions.h" #include "CSE.h" #include "Debug.h" #include "Derivative.h" #include "DerivativeUtils.h" #include "Error.h" #include "ExprUsesVar.h" #include "FindCalls.h" #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" #include "RealizationOrder.h" #include "Simplify.h" #include "Solve.h" #include "Substitute.h" namespace Halide { using std::map; using std::set; using std::string; using std::vector; using FuncKey = Derivative::FuncKey; namespace Internal { namespace { bool is_float_extern(const string &op_name, const string &func_name) { return op_name == (func_name + "_f16") || op_name == (func_name + "_f32") || op_name == (func_name + "_f64"); }; /** Compute derivatives through reverse accumulation */ class ReverseAccumulationVisitor : public IRVisitor { public: void propagate_adjoints(const Func &output, const Func &adjoint, const Region &output_bounds); map get_adjoint_funcs() const { return adjoint_funcs; } protected: void visit(const IntImm *) override; void visit(const UIntImm *) override; void visit(const FloatImm *) override; void visit(const StringImm *) override; void visit(const Cast *op) override; void visit(const Reinterpret *op) override; void visit(const Variable *op) override; void visit(const Add *op) override; void visit(const Sub *op) override; void visit(const Mul *op) override; void visit(const Div *op) override; void visit(const Mod *op) override; void visit(const Min *op) override; void visit(const Max *op) override; void visit(const EQ *op) override; void visit(const NE *op) override; void visit(const LT *op) override; void visit(const LE *op) override; void visit(const GT *op) override; void visit(const GE *op) override; void visit(const And *) override; void visit(const Or *) override; void visit(const Not *) override; void visit(const Select *op) override; void visit(const Let *op) override; void visit(const Call *op) override; void visit(const Load *op) override { internal_error << "Encounter unexpected expression \"Load\" when differentiating."; } void visit(const Ramp *op) override { internal_error << "Encounter unexpected expression \"Ramp\" when differentiating."; } void visit(const Broadcast *op) override { internal_error << "Encounter unexpected expression \"Broadcast\" when differentiating."; } void visit(const Shuffle *op) override { internal_error << "Encounter unexpected expression \"Shuffle\" when differentiating."; } void visit(const VectorReduce *op) override { internal_error << "Encounter unexpected expression \"VectorReduce\" when differentiating."; } void visit(const LetStmt *op) override { internal_error << "Encounter unexpected statement \"LetStmt\" when differentiating."; } void visit(const AssertStmt *op) override { internal_error << "Encounter unexpected statement \"AssertStmt\" when differentiating."; } void visit(const ProducerConsumer *op) override { internal_error << "Encounter unexpected statement \"ProducerConsumer\" when differentiating."; } void visit(const For *op) override { internal_error << "Encounter unexpected statement \"For\" when differentiating."; } void visit(const Store *op) override { internal_error << "Encounter unexpected statement \"Store\" when differentiating."; } void visit(const Provide *op) override { internal_error << "Encounter unexpected statement \"Provide\" when differentiating."; } void visit(const Allocate *op) override { internal_error << "Encounter unexpected statement \"Allocate\" when differentiating."; } void visit(const Free *op) override { internal_error << "Encounter unexpected statement \"Free\" when differentiating."; } void visit(const Realize *op) override { internal_error << "Encounter unexpected statement \"Realize\" when differentiating."; } void visit(const Block *op) override { internal_error << "Encounter unexpected statement \"Block\" when differentiating."; } void visit(const IfThenElse *op) override { internal_error << "Encounter unexpected statement \"IfThenElse\" when differentiating."; } void visit(const Evaluate *op) override { internal_error << "Encounter unexpected statement \"Evaluate\" when differentiating."; } void visit(const Prefetch *op) override { internal_error << "Encounter unexpected statement \"Prefetch\" when differentiating."; } void visit(const Fork *op) override { internal_error << "Encounter unexpected statement \"Fork\" when differentiating."; } void visit(const Acquire *op) override { internal_error << "Encounter unexpected statement \"Acquire\" when differentiating."; } void visit(const Atomic *op) override { internal_error << "Encounter unexpected statement \"Atomic\" when differentiating."; } void visit(const HoistedStorage *op) override { internal_error << "Encounter unexpected statement \"HoistedStorage\" when differentiating."; } private: void accumulate(const Expr &stub, Expr adjoint); void propagate_halide_function_call( Expr adjoint, const std::string &name, // called function name const FunctionPtr &func_ptr, // pointer to halide function, is null if this is a call to buffer or param const std::vector &call_args, // call arguments int value_index, // which element in the tuple const Type &type // return type of the called function ); // For each expression, we store the accumulated adjoints expression map expr_adjoints; // For each function and each update, we store the accumulated adjoints func map adjoint_funcs; // Let variables and their mapping map let_var_mapping; vector let_variables; // Bounds of functions map func_bounds; // Current function that scatters its adjoints to its dependencies Func current_func; // Current update of the function int current_update_id; // We compute the derivatives in several passes. // Sometimes we don't want to propagate through Halide function calls bool is_forward_overwrite_detection_phase; bool is_self_referencing_phase; // Is the current function update a non overwriting scan? bool is_current_non_overwriting_scan; // A temporary flag for checking the derivatives // to self reference of a Halide function is 1 or not // Used in forward overwrite detection phase Tuple self_reference_adjoint = Tuple(Expr()); vector> self_reference_args; }; void ReverseAccumulationVisitor::propagate_adjoints( const Func &output, const Func &adjoint, const Region &output_bounds) { // Topologically sort the functions map env = find_transitive_calls(output.function()); vector order = realization_order({output.function()}, env).first; vector funcs; funcs.reserve(order.size()); // Internal::debug(0) << "Sorted Func list:\n"; // for (const auto &func_name : order) { // Internal::debug(0) << " . " << func_name << "\n"; // } for (const auto &func_name : order) { funcs.emplace_back(env[func_name]); } internal_assert(!funcs.empty()); // If the derivatives depend on an in-place overwrite, // and the self reference adjoint is not 0 or 1, // throws an error to the users. // For example: // // 1. // f(x) = g(x) // f(x) = f(x) * f(x) // f'(x) depends on first f(x) // // 2. // f(x) = 0 // f(x) = 2 * f(x) + g(r.x) // g'(r.x) depends on intermediate f'(x) // // The following is fine because the self reference adjoint is 1: // f(x) = f(x) + g(r.x) // (when it's 1 all instances of f(x) have the same adjoint) // // The issue is that the self reference to f makes propagation to g // using the wrong adjoints. // // The user should rewrite the above updates to the following: // // 1. // f_(x, 0) = g(x) // f_(x, 1) = f_(x, 0) * f_(x, 0) // f(x) = f_(x, 1) // // 2. // f_(x, 0) = 0 // f_(x, r.x + 1) = 2 * f_(x, r.x) + g(r.x) // f(x) = f_(x, r.x.max() + 1) // // We can do the rewrite for the users automatically, but it requires // generating the indirect reference f_, making scheduling these // functions extremely difficult. is_forward_overwrite_detection_phase = true; set non_overwriting_scans; for (auto &func : funcs) { current_func = func; // Precompute the left hand side intervals for each update // We use this to determine if there's overlaps between the updates vector boxes; boxes.reserve(func.num_update_definitions()); for (int update_id = 0; update_id < func.num_update_definitions(); update_id++) { const vector &args = func.update_args(update_id); vector intervals; intervals.reserve(args.size()); for (const auto &arg : args) { Scope scope; ReductionDomain rdom = extract_rdom(arg); if (rdom.defined()) { const vector &rvars = rdom.domain(); for (const auto &r : rvars) { Expr r_max = simplify(r.min + r.extent + 1); scope.push(r.var, Interval(r.min, r_max)); } } Interval interval = bounds_of_expr_in_scope(arg, scope); intervals.push_back(interval); } boxes.emplace_back(intervals); } for (int update_id = 0; update_id < func.num_update_definitions(); update_id++) { // We check for two criteria: // 1. We check if the derivatives // depend on previous update, and if that particular // value has been overwritten. // 2. For updates of f with reduction variables, // unless the derivatives to self reference is 1 or 0, // we make sure overwritten f' is not used by others. // We conservatively detect this by distinguish two cases: // a. If f' is always never being overwritten for all instances of // the reduction variables // b. Or if f' is never used by others except itself. // // A few examples: // // f(x) = f(x) + g(r.x) // good, the self update derivative is 1 // // f(x) = 2 * f(x) // good, although the self update derivative is 2, // there's no reduction variables // // f(x) = 2 * f(x) + g(r.x) // bad, f'(x) will be used for updating // g(r.x) but will be overwritten // // f(x) = f(x) * f(x) // bad, derivative of f(x) depends on previous value // which has been overwritten // // f(x, 0) = ... // f(x, 1) = f(x, 0) * f(x, 0) // good, although the derivative depends on // // previous value, the updates do not overlap // // f(x, r.x + 1) = 2 * f(x, r.x) + g(r.x) // good, // // f' is never overwritten // // f(x, y) = g(x) // f(x, r.x + 1) = f(x, r.x) * f(x, r.x); // bad, the derivatives // depend on previous updates // // f(x, y, 0) = g(x) // f(x, r.x + 1, 1) = f(x, r.x, 0) * f(x, r.x, 0); // good // // f(x, r.x + 1, r.y + 1) = 2 * f(x, r.x, r.y) + g(r.x) // good // // f(x, r.x + 1, r.x + r.y + 1) = 2 * f(x, r.x, r.y) + g(r.x) // bad vector zeros; Tuple rhs_tuple = func.values(); zeros.reserve(rhs_tuple.size()); for (int i = 0; i < (int)rhs_tuple.size(); i++) { zeros.push_back(make_zero(rhs_tuple[i].type())); } self_reference_adjoint = Tuple(zeros); self_reference_args.clear(); // Checking 1. here: // Take the derivative at expression level, the results are // stored in expr_adjoints vector expr_list; Tuple update_tuple = func.update_values(update_id); vector output_exprs; const vector &update_tuple_vector = update_tuple.as_vector(); for (const auto &expr : update_tuple_vector) { vector value_expr_list = sort_expressions(expr); expr_list.insert(expr_list.end(), value_expr_list.begin(), value_expr_list.end()); output_exprs.push_back((const BaseExprNode *)expr_list.back().get()); } // TODO: replace let_var_mapping with Scope // Gather let variables let_var_mapping.clear(); let_variables.clear(); for (const auto &expr : expr_list) { if (expr.get()->node_type == IRNodeType::Let) { const Let *op = expr.as(); // Assume Let variables are unique internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end()); let_var_mapping[op->name] = op->value; let_variables.push_back(op->name); } } // Set the output adjoint to 1 // We're not really propagating adjoints, just checking if there's // self references for (auto &output_expr : output_exprs) { expr_adjoints[output_expr] = 1.f; } // Traverse the expressions in reverse order for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) { if (it->type().is_handle()) { // Ignore pointer types continue; } it->accept(this); } auto error = [&]() { user_error << "Can't take the gradients of " << func.name() << ", which depend on intermediate values. " << "Use a scan (which saves intermediate results) instead."; }; // For each adjoint expression depositing to a function or image, // check if it references to the function bool adjoints_used_by_others = false; for (const auto &it : expr_adjoints) { Expr target_expr(it.first); bool is_target_func_or_buffer = false; const Call *call_op = target_expr.as(); if (call_op != nullptr) { is_target_func_or_buffer = call_op->call_type == Call::Image || call_op->call_type == Call::Halide; } Expr expr = it.second; if (is_target_func_or_buffer && is_calling_function(func.name(), expr, let_var_mapping)) { // Self reference might not be bad. // If we carefully avoid overwriting intermediate values, // we can still backprop. // First we check for the pure definition. // If the pure definition depends on any functions or buffers, // there is no hope since we will overwrite something Tuple rhs_tuple = func.values(); for (int tuple_id = 0; tuple_id < (int)rhs_tuple.size(); tuple_id++) { if (is_calling_function(rhs_tuple[tuple_id], let_var_mapping)) { error(); } } // Now we check all previous updates, see if the left hand // side arguments overlap. Box current_box = boxes[update_id]; for (int prev_update_id = 0; prev_update_id < update_id; prev_update_id++) { // Gather two boxes from current update and previous update Box prev_box = boxes[prev_update_id]; internal_assert(current_box.size() == prev_box.size()); // If any of the boxes overlap, we need to throw an error if (boxes_overlap(current_box, prev_box)) { error(); } } } if (is_target_func_or_buffer && call_op->name != func.name()) { adjoints_used_by_others = true; } } expr_adjoints.clear(); // Checking 2. here: bool all_zero_or_one_self_adjoint = true; for (int i = 0; i < (int)self_reference_adjoint.size(); i++) { if (!is_const(self_reference_adjoint[i], 0) && !is_const(self_reference_adjoint[i], 1)) { all_zero_or_one_self_adjoint = false; break; } } bool has_reduction_var = !func.rvars(update_id).empty(); if (!all_zero_or_one_self_adjoint && has_reduction_var) { // a. is there any instance of reduction variable such that // the self reference update overwrites itself? // Or, equivalently, for all possible values of the reduction // variables, does the self reference update always // reads from/writes to different locations? // First we determine the ranges of RDoms for // and_condition_over_domain Scope varying; // Loop over lhs & rhs to grab a reduction domain ReductionDomain r; const vector &update_args = func.update_args(update_id); for (const Expr &expr : update_args) { r = extract_rdom(expr); if (r.defined()) { break; } } if (!r.defined()) { for (int tuple_id = 0; tuple_id < (int)update_tuple.size(); tuple_id++) { r = extract_rdom(update_tuple[tuple_id]); if (r.defined()) { break; } } } internal_assert(r.defined()); // Go over all self reference call arguments bool is_not_overwriting = true; for (const vector &self_ref_args : self_reference_args) { internal_assert(self_ref_args.size() == update_args.size()); Expr not_overwriting_cond = const_false(); for (int arg_id = 0; arg_id < (int)self_ref_args.size(); arg_id++) { // Are the read from/write to arguments always different? not_overwriting_cond = simplify(not_overwriting_cond || (self_ref_args[arg_id] != update_args[arg_id])); } not_overwriting_cond = and_condition_over_domain( not_overwriting_cond, varying); // Needs to be true for all self reference is_not_overwriting = is_not_overwriting && can_prove(not_overwriting_cond); } // b. Even if the derivative is overwritten, as long as // we don't use it in this update we are good. // Otherwise we throw an error if (!is_not_overwriting && adjoints_used_by_others) { error(); } if (is_not_overwriting) { // This is a non overwriting scan, let's remember it non_overwriting_scans.insert(FuncKey{func.name(), update_id}); } } } } is_forward_overwrite_detection_phase = false; // Bounds inference Box output_box; for (const auto &p : output_bounds) { // Convert from min,extent to min,max output_box.push_back(Interval(p.min, p.min + p.extent)); } func_bounds = inference_bounds(output, output_box); for (const auto &it : func_bounds) { const Box &bounds = it.second; for (int d = 0; d < (int)bounds.size(); d++) { user_assert(bounds[d].is_bounded()) << "Access to function or buffer " << it.first << " at dimension " << d << " is not bounded. " << "We can only differentiate bounded accesses.\n"; } } // Create a stub for each function and each update to accumulate adjoints. for (int func_id = 0; func_id < (int)funcs.size(); func_id++) { const Func &func = funcs[func_id]; for (int update_id = -1; update_id < func.num_update_definitions(); update_id++) { Func adjoint_func(func.name() + "_" + std::to_string(update_id + 1) + "_d_def__"); bool is_final_output = func_id == (int)funcs.size() - 1 && update_id == func.num_update_definitions() - 1; vector args = func.args(); for (auto &arg : args) { if (arg.is_implicit()) { // Replace implicit variables with non implicit ones arg = Var(); } } if (is_final_output) { adjoint_func(args) = adjoint(args); } else { // Initialize to 0 if (func.values().size() == 1) { adjoint_func(args) = make_zero(func.values()[0].type()); } else { vector init(func.values().size()); for (int i = 0; i < (int)init.size(); i++) { init[i] = make_zero(func.values()[i].type()); } adjoint_func(args) = Tuple(init); } } FuncKey func_key{func.name(), update_id}; internal_assert(adjoint_funcs.find(func_key) == adjoint_funcs.end()); adjoint_funcs[func_key] = adjoint_func; } } // Also create stubs for buffers referenced by the functions map called_buffers_or_param; for (auto &func : funcs) { map buffers = find_buffer_param_calls(func); called_buffers_or_param.insert(buffers.begin(), buffers.end()); } for (const auto &it : called_buffers_or_param) { // Replace all the dots in the function names to make it legal. Func adjoint_func(replace_all(it.first, ".", "_") + "_d__"); vector args(it.second.dimension); adjoint_func(args) = make_zero(it.second.type); FuncKey func_key{it.first, -1}; if (adjoint_funcs.find(func_key) != adjoint_funcs.end()) { user_error << "Naming conflict between buffer/parameters and function:" << it.first << "\n"; } adjoint_funcs[func_key] = adjoint_func; } // Traverse functions from producers to consumers for reverse accumulation for (int func_id = funcs.size() - 1; func_id >= 0; func_id--) { const Func &func = funcs[func_id]; current_func = func; FuncKey func_key{func.name(), func.num_update_definitions() - 1}; // Traverse from the last update to first for (int update_id = func.num_update_definitions() - 1; update_id >= -1; update_id--) { current_update_id = update_id; FuncKey func_key{func.name(), update_id}; Func adjoint_func = adjoint_funcs[func_key]; internal_assert(func_bounds.find(func.name()) != func_bounds.end()); // The propagation of adjoints to self reference goes to // current update instead of previous if it's a non overwriting scan is_current_non_overwriting_scan = false; if (update_id >= 0) { auto it = non_overwriting_scans.find(func_key); if (it != non_overwriting_scans.end()) { is_current_non_overwriting_scan = true; } } // Initialize the next adjoint function by // propagating the adjoints to next update // Example: // f(x) = ... // f(1) = ... <- we're here // We have an adjoint for f(1) defined over the whole support of f // Now we want to initialize for the f(x) update // Need to propagate back to all x while masking 1 // x -> next_args // 1 -> update_args auto mask_previous_update = [&]() { FuncKey prev_func_key{func.name(), update_id - 1}; Func &prev_adjoint_func = adjoint_funcs[prev_func_key]; vector prev_args = prev_adjoint_func.args(); vector update_args = func.update_args(update_id); // Replace implicit variables for (auto &arg : update_args) { set implicit_variables = find_implicit_variables(arg); for (const auto &var : implicit_variables) { arg = substitute(var, prev_args[Var::implicit_index(var)], arg); } } // Check if prev_args are the same as update_args // If they are the same simply set everything to zero bool is_noop = true; for (int i = 0; i < (int)prev_args.size(); i++) { const Variable *update_var = update_args[i].as(); if (update_var == nullptr || prev_args[i].name() != update_var->name) { is_noop = false; } } prev_adjoint_func = Func(prev_adjoint_func.name()); if (!is_noop) { // f'(x) = adjoint prev_adjoint_func(prev_args) = adjoint_funcs[func_key](prev_args); if (func.values().size() == 1) { Type type = func.values()[0].type(); prev_adjoint_func(update_args) = make_zero(type); } else { vector init(func.values().size()); for (int i = 0; i < (int)init.size(); i++) { init[i] = make_zero(func.values()[i].type()); } prev_adjoint_func(update_args) = Tuple(init); } } else { if (func.values().size() == 1) { Type type = func.values()[0].type(); prev_adjoint_func(prev_args) = make_zero(type); } else { vector init(func.values().size()); for (int i = 0; i < (int)init.size(); i++) { init[i] = make_zero(func.values()[i].type()); } prev_adjoint_func(prev_args) = Tuple(init); } } }; if (update_id >= 0 && !is_current_non_overwriting_scan) { // Delay the masking if we're keeping track of intermediate values. // Since in this case we are propagating to current update, // instead of previous update. mask_previous_update(); } // Now we want to propagate the derivatives at expression level. // We topologically sort the expressions for each value in the tuple. vector expr_list; Tuple rhs_tuple = update_id < 0 ? func.values() : func.update_values(update_id); vector output_exprs; const vector &rhs_tuple_vector = rhs_tuple.as_vector(); for (const auto &expr : rhs_tuple_vector) { vector value_expr_list = sort_expressions(expr); expr_list.insert( expr_list.end(), value_expr_list.begin(), value_expr_list.end()); output_exprs.push_back((const BaseExprNode *)expr_list.back().get()); } // TODO: replace let_var_mapping with Scope // Gather let variables let_var_mapping.clear(); let_variables.clear(); for (const auto &expr : expr_list) { if (expr.get()->node_type == IRNodeType::Let) { const Let *op = expr.as(); // Assume Let variables are unique internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end()); let_var_mapping[op->name] = op->value; let_variables.push_back(op->name); } } // Retrieve previously propagated adjoint for the Func, // apply it to expression adjoints. // f(x) = g(x) // d_g(x) = d_f(x) * df/dg vector update_args; if (update_id >= 0) { update_args = func.update_args(update_id); } else { update_args.reserve(func.args().size()); Func adjoint_func = adjoint_funcs[func_key]; for (const auto &var : adjoint_func.args()) { update_args.push_back(var); } } // We propagate in two phases, the first phase only propagates // to self references, the second phase propagates to the rest. { // First phase is_self_referencing_phase = true; expr_adjoints.clear(); if (output_exprs.size() == 1) { expr_adjoints[output_exprs[0]] = (adjoint_funcs[func_key])(update_args); } else { for (int i = 0; i < (int)output_exprs.size(); i++) { expr_adjoints[output_exprs[i]] = (adjoint_funcs[func_key])(update_args)[i]; } } // Traverse the expressions in reverse order for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) { if (it->type().is_handle()) { // Ignore pointer types continue; } // Propagate adjoints it->accept(this); } } if (is_current_non_overwriting_scan) { // Now, if we detect a non-overwriting scan operation, // the update of adjoints goes to the current function. // We let the previous adjoint the same as the current one FuncKey prev_func_key{func_key.first, func_key.second - 1}; // Recreate a new adjoint for previous update Func prev_adjoint; vector args; args.reserve(adjoint_func.args().size()); for (const auto &arg : adjoint_func.args()) { args.push_back(arg); } vector calls; calls.reserve(rhs_tuple.size()); for (int i = 0; i < (int)rhs_tuple.size(); i++) { calls.push_back(Call::make( adjoint_funcs[func_key].function(), args, i)); } prev_adjoint(args) = Tuple(calls); adjoint_funcs[prev_func_key] = prev_adjoint; mask_previous_update(); } { // Second phase is_self_referencing_phase = false; expr_adjoints.clear(); for (int i = 0; i < (int)output_exprs.size(); i++) { expr_adjoints[output_exprs[i]] = Call::make(adjoint_funcs[func_key].function(), update_args, i); } // Traverse the expressions in reverse order for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) { if (it->type().is_handle()) { // Ignore pointer types continue; } // Propagate adjoints it->accept(this); } } } } } void ReverseAccumulationVisitor::accumulate(const Expr &stub, Expr adjoint) { const BaseExprNode *stub_ptr = (const BaseExprNode *)stub.get(); // Trick to avoid NaN in select() clauses: // select(c, x, 0) * y -> select(c, x * y, 0) // x * select(c, y, 0) -> select(c, x * y, 0) // select(c, x, 0) / y -> select(c, x / y, 0) if (adjoint.as() != nullptr) { const Mul *mul_op = adjoint.as(); auto mul_select_with_zero = [&](const Expr &sel, const Expr &other) { const Select *sel_op = sel.as() != nullptr) { adjoint = mul_select_with_zero(mul_op->a, mul_op->b); } else if (mul_op->b.as(); if (is_const_zero(sel_op->true_value)) { return select(sel_op->condition, sel_op->true_value, sel_op->false_value / other); } if (is_const_zero(sel_op->false_value)) { return select(sel_op->condition, sel_op->true_value / other, sel_op->false_value); } return sel * other; }; if (div_op->a.as() != nullptr) { const Select *sel_op = adjoint.as