https://github.com/halide/Halide
Revision 41704275655165d69a147a502b62fe296eb72be1 authored by Andrew Adams on 02 September 2020, 22:32:24 UTC, committed by Andrew Adams on 02 September 2020, 22:32:24 UTC
1 parent a054a91
Raw File
Tip revision: 41704275655165d69a147a502b62fe296eb72be1 authored by Andrew Adams on 02 September 2020, 22:32:24 UTC
Remove dead split
Tip revision: 4170427
Derivative.cpp
#include <cmath>
#include <iostream>
#include <set>

#include "Associativity.h"
#include "BoundaryConditions.h"
#include "CSE.h"
#include "Debug.h"
#include "Derivative.h"
#include "DerivativeUtils.h"
#include "Error.h"
#include "ExprUsesVar.h"
#include "FindCalls.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "RealizationOrder.h"
#include "Simplify.h"
#include "Solve.h"
#include "Substitute.h"

namespace Halide {

using std::map;
using std::pair;
using std::set;
using std::string;
using std::vector;
using FuncKey = Derivative::FuncKey;

namespace Internal {
namespace {

/** Returns true when op_name is func_name followed by one of the float
 * extern suffixes "_f16", "_f32" or "_f64" (e.g. "sin_f32" for "sin"). */
bool is_float_extern(const string &op_name,
                     const string &func_name) {
    for (const char *suffix : {"_f16", "_f32", "_f64"}) {
        if (op_name == func_name + suffix) {
            return true;
        }
    }
    return false;
}

/** Compute derivatives through reverse accumulation
 *
 * propagate_adjoints() walks the functions reachable from an output Func in
 * reverse realization order and builds one adjoint Func per function and per
 * definition (pure definition and each update). get_adjoint_funcs() returns
 * the resulting map keyed by FuncKey. Buffers/params referenced by the
 * functions also receive an adjoint stub.
 *
 * Statement nodes, as well as the Load/Ramp/Broadcast/Shuffle/VectorReduce
 * expression nodes, are not expected in the differentiated expressions and
 * trigger an internal error when visited.
 */
class ReverseAccumulationVisitor : public IRVisitor {
public:
    void propagate_adjoints(const Func &output,
                            const Func &adjoint,
                            const Region &output_bounds);

    map<FuncKey, Func> get_adjoint_funcs() const {
        return adjoint_funcs;
    }

protected:
    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const FloatImm *) override;
    void visit(const StringImm *) override;
    void visit(const Cast *op) override;
    void visit(const Variable *op) override;
    void visit(const Add *op) override;
    void visit(const Sub *op) override;
    void visit(const Mul *op) override;
    void visit(const Div *op) override;
    void visit(const Mod *op) override;
    void visit(const Min *op) override;
    void visit(const Max *op) override;
    void visit(const EQ *op) override;
    void visit(const NE *op) override;
    void visit(const LT *op) override;
    void visit(const LE *op) override;
    void visit(const GT *op) override;
    void visit(const GE *op) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Select *op) override;
    void visit(const Let *op) override;
    void visit(const Call *op) override;

    // Nodes that must never reach the differentiator.
    void visit(const Load *) override {
        internal_error << "Encounter unexpected expression \"Load\" when differentiating.";
    }
    void visit(const Ramp *) override {
        internal_error << "Encounter unexpected expression \"Ramp\" when differentiating.";
    }
    void visit(const Broadcast *) override {
        internal_error << "Encounter unexpected expression \"Broadcast\" when differentiating.";
    }
    void visit(const Shuffle *) override {
        internal_error << "Encounter unexpected expression \"Shuffle\" when differentiating.";
    }
    void visit(const VectorReduce *) override {
        internal_error << "Encounter unexpected expression \"VectorReduce\" when differentiating.";
    }
    void visit(const LetStmt *) override {
        internal_error << "Encounter unexpected statement \"LetStmt\" when differentiating.";
    }
    void visit(const AssertStmt *) override {
        internal_error << "Encounter unexpected statement \"AssertStmt\" when differentiating.";
    }
    void visit(const ProducerConsumer *) override {
        internal_error << "Encounter unexpected statement \"ProducerConsumer\" when differentiating.";
    }
    void visit(const For *) override {
        internal_error << "Encounter unexpected statement \"For\" when differentiating.";
    }
    void visit(const Store *) override {
        internal_error << "Encounter unexpected statement \"Store\" when differentiating.";
    }
    void visit(const Provide *) override {
        internal_error << "Encounter unexpected statement \"Provide\" when differentiating.";
    }
    void visit(const Allocate *) override {
        internal_error << "Encounter unexpected statement \"Allocate\" when differentiating.";
    }
    void visit(const Free *) override {
        internal_error << "Encounter unexpected statement \"Free\" when differentiating.";
    }
    void visit(const Realize *) override {
        internal_error << "Encounter unexpected statement \"Realize\" when differentiating.";
    }
    void visit(const Block *) override {
        internal_error << "Encounter unexpected statement \"Block\" when differentiating.";
    }
    void visit(const IfThenElse *) override {
        internal_error << "Encounter unexpected statement \"IfThenElse\" when differentiating.";
    }
    void visit(const Evaluate *) override {
        internal_error << "Encounter unexpected statement \"Evaluate\" when differentiating.";
    }
    void visit(const Prefetch *) override {
        internal_error << "Encounter unexpected statement \"Prefetch\" when differentiating.";
    }
    void visit(const Fork *) override {
        internal_error << "Encounter unexpected statement \"Fork\" when differentiating.";
    }
    void visit(const Acquire *) override {
        internal_error << "Encounter unexpected statement \"Acquire\" when differentiating.";
    }
    void visit(const Atomic *) override {
        internal_error << "Encounter unexpected statement \"Atomic\" when differentiating.";
    }

private:
    void accumulate(const Expr &stub, Expr adjoint);

    void propagate_halide_function_call(
        Expr adjoint,
        const std::string &name,             // called function name
        const FunctionPtr &func_ptr,         // pointer to halide function, is null if this is a call to buffer or param
        const std::vector<Expr> &call_args,  // call arguments
        int value_index,                     // which element in the tuple
        const Type &type                     // return type of the called function
    );

    // For each expression, we store the accumulated adjoints expression
    map<const BaseExprNode *, Expr> expr_adjoints;
    // For each function and each update, we store the accumulated adjoints func
    map<FuncKey, Func> adjoint_funcs;
    // Let variables and their mapping
    map<string, Expr> let_var_mapping;
    vector<string> let_variables;
    // Bounds of functions
    map<string, Box> func_bounds;
    // Current function that scatters its adjoints to its dependencies
    Func current_func;
    // Current update of the function; -1 denotes the pure definition (the
    // same convention as the FuncKeys built in propagate_adjoints).
    // Initialized so that visits running before propagate_adjoints assigns
    // it (e.g. during the overwrite-detection pass) never read an
    // indeterminate value.
    int current_update_id = -1;
    // We compute the derivatives in several passes.
    // Sometimes we don't want to propagate through Halide function calls.
    // All phase flags start false; propagate_adjoints sets them explicitly.
    bool is_forward_overwrite_detection_phase = false;
    bool is_self_referencing_phase = false;
    // Is the current function update a non overwriting scan?
    bool is_current_non_overwriting_scan = false;
    // A temporary flag for checking the derivatives
    // to self reference of a Halide function is 1 or not
    // Used in forward overwrite detection phase
    Tuple self_reference_adjoint = Tuple(Expr());
    vector<vector<Expr>> self_reference_args;
};

void ReverseAccumulationVisitor::propagate_adjoints(
    const Func &output,
    const Func &adjoint,
    const Region &output_bounds) {
    // Topologically sort the functions
    map<string, Function> env = find_transitive_calls(output.function());
    vector<string> order =
        realization_order({output.function()}, env).first;
    vector<Func> funcs;
    funcs.reserve(order.size());
    // Internal::debug(0) << "Sorted Func list:\n";
    // for (const auto &func_name : order) {
    //     Internal::debug(0) << "  . " << func_name << "\n";
    // }
    for (const auto &func_name : order) {
        funcs.emplace_back(env[func_name]);
    }
    internal_assert(!funcs.empty());

    // If the derivatives depend on an in-place overwrite,
    // and the self reference adjoint is not 0 or 1,
    // throws an error to the users.
    // For example:
    //
    // 1.
    // f(x) = g(x)
    // f(x) = f(x) * f(x)
    // f'(x) depends on first f(x)
    //
    // 2.
    // f(x) = 0
    // f(x) = 2 * f(x) + g(r.x)
    // g'(r.x) depends on intermediate f'(x)
    //
    // The following is fine because the self reference adjoint is 1:
    // f(x) = f(x) + g(r.x)
    // (when it's 1 all instances of f(x) have the same adjoint)
    //
    // The issue is that the self reference to f makes propagation to g
    // using the wrong adjoints.
    //
    // The user should rewrite the above updates to the following:
    //
    // 1.
    // f_(x, 0) = g(x)
    // f_(x, 1) = f_(x, 0) * f_(x, 0)
    // f(x) = f_(x, 1)
    //
    // 2.
    // f_(x, 0) = 0
    // f_(x, r.x + 1) = 2 * f_(x, r.x) + g(r.x)
    // f(x) = f_(x, r.x.max() + 1)
    //
    // We can do the rewrite for the users automatically, but it requires
    // generating the indirect reference f_, making scheduling these
    // functions extremely difficult.
    is_forward_overwrite_detection_phase = true;
    set<FuncKey> non_overwriting_scans;
    for (int func_id = 0; func_id < (int)funcs.size(); func_id++) {
        const Func &func = funcs[func_id];
        current_func = func;
        // Precompute the left hand side intervals for each update
        // We use this to determine if there's overlaps between the updates
        vector<Box> boxes;
        boxes.reserve(func.num_update_definitions());
        for (int update_id = 0;
             update_id < func.num_update_definitions(); update_id++) {
            const vector<Expr> &args = func.update_args(update_id);
            vector<Interval> intervals;
            intervals.reserve(args.size());
            for (int arg_id = 0; arg_id < (int)args.size(); arg_id++) {
                Scope<Interval> scope;
                ReductionDomain rdom = extract_rdom(args[arg_id]);
                if (rdom.defined()) {
                    const vector<ReductionVariable> &rvars = rdom.domain();
                    for (const auto &r : rvars) {
                        Expr r_max = simplify(r.min + r.extent + 1);
                        scope.push(r.var, Interval(r.min, r_max));
                    }
                }
                Interval interval = bounds_of_expr_in_scope(args[arg_id], scope);
                intervals.push_back(interval);
            }
            boxes.emplace_back(intervals);
        }
        for (int update_id = 0;
             update_id < func.num_update_definitions(); update_id++) {
            // We check for two criteria:
            // 1. We check if the derivatives
            //    depend on previous update, and if that particular
            //    value has been overwritten.
            // 2. For updates of f with reduction variables,
            //    unless the derivatives to self reference is 1 or 0,
            //    we make sure overwritten f' is not used by others.
            //    We conservatively detect this by distinguish two cases:
            //    a. If f' is always never being overwritten for all instances of
            //       the reduction variables
            //    b. Or if f' is never used by others except itself.
            //
            // A few examples:
            //
            // f(x) = f(x) + g(r.x) // good, the self update derivative is 1
            //
            // f(x) = 2 * f(x) // good, although the self update derivative is 2,
            //                    there's no reduction variables
            //
            // f(x) = 2 * f(x) + g(r.x) // bad, f'(x) will be used for updating
            //                             g(r.x) but will be overwritten
            //
            // f(x) = f(x) * f(x) // bad, derivative of f(x) depends on previous value
            //                       which has been overwritten
            //
            // f(x, 0) = ...
            // f(x, 1) = f(x, 0) * f(x, 0) // good, although the derivative depends on
            //                          // previous value, the updates do not overlap
            //
            // f(x, r.x + 1) = 2 * f(x, r.x) + g(r.x) // good,
            //                                      // f' is never overwritten
            //
            // f(x, y) = g(x)
            // f(x, r.x + 1) = f(x, r.x) * f(x, r.x); // bad, the derivatives
            //                                           depend on previous updates
            //
            // f(x, y, 0) = g(x)
            // f(x, r.x + 1, 1) = f(x, r.x, 0) * f(x, r.x, 0); // good
            //
            // f(x, r.x + 1, r.y + 1) = 2 * f(x, r.x, r.y) + g(r.x) // good
            //
            // f(x, r.x + 1, r.x + r.y + 1) = 2 * f(x, r.x, r.y) + g(r.x) // bad

            vector<Expr> zeros;
            Tuple rhs_tuple = func.values();
            zeros.reserve(rhs_tuple.size());
            for (int i = 0; i < (int)rhs_tuple.size(); i++) {
                zeros.push_back(make_zero(rhs_tuple[i].type()));
            }
            self_reference_adjoint = Tuple(zeros);
            self_reference_args.clear();
            // Checking 1. here:
            // Take the derivative at expression level, the results are
            // stored in expr_adjoints
            vector<Expr> expr_list;
            Tuple update_tuple = func.update_values(update_id);
            vector<const BaseExprNode *> output_exprs;
            const vector<Expr> &update_tuple_vector = update_tuple.as_vector();
            for (const auto &expr : update_tuple_vector) {
                vector<Expr> value_expr_list = sort_expressions(expr);
                expr_list.insert(expr_list.end(),
                                 value_expr_list.begin(), value_expr_list.end());
                output_exprs.push_back((const BaseExprNode *)expr_list.back().get());
            }

            // TODO: replace let_var_mapping with Scope
            // Gather let variables
            let_var_mapping.clear();
            let_variables.clear();
            for (auto it = expr_list.begin(); it != expr_list.end(); it++) {
                Expr expr = *it;
                if (expr.get()->node_type == IRNodeType::Let) {
                    const Let *op = expr.as<Let>();
                    // Assume Let variables are unique
                    internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end());
                    let_var_mapping[op->name] = op->value;
                    let_variables.push_back(op->name);
                }
            }

            // Set the output adjoint to 1
            // We're not really propagating adjoints, just checking if there's
            // self references
            for (int i = 0; i < (int)output_exprs.size(); i++) {
                expr_adjoints[output_exprs[i]] = 1.f;
            }

            // Traverse the expressions in reverse order
            for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) {
                if (it->type().is_handle()) {
                    // Ignore pointer types
                    continue;
                }
                it->accept(this);
            }

            auto error = [&]() {
                user_error << "Can't take the gradients of " << func.name() << ", which depend on intermediate values. "
                           << "Use a scan (which saves intermediate results) instead.";
            };

            // For each adjoint expression depositing to a function or image,
            // check if it references to the function
            bool adjoints_used_by_others = false;
            for (const auto &it : expr_adjoints) {
                Expr target_expr(it.first);
                bool is_target_func_or_buffer = false;
                const Call *call_op = target_expr.as<Call>();
                if (call_op != nullptr) {
                    is_target_func_or_buffer =
                        call_op->call_type == Call::Image ||
                        call_op->call_type == Call::Halide;
                }
                Expr expr = it.second;
                if (is_target_func_or_buffer &&
                    is_calling_function(func.name(), expr, let_var_mapping)) {
                    // Self reference might not be bad.
                    // If we carefully avoid overwriting intermediate values,
                    // we can still backprop.
                    // First we check for the pure definition.
                    // If the pure definition depends on any functions or buffers,
                    // there is no hope since we will overwrite something
                    Tuple rhs_tuple = func.values();
                    for (int tuple_id = 0; tuple_id < (int)rhs_tuple.size();
                         tuple_id++) {
                        if (is_calling_function(rhs_tuple[tuple_id], let_var_mapping)) {
                            error();
                        }
                    }
                    // Now we check all previous updates, see if the left hand
                    // side arguments overlap.
                    Box current_box = boxes[update_id];
                    for (int prev_update_id = 0; prev_update_id < update_id;
                         prev_update_id++) {
                        // Gather two boxes from current update and previous update
                        Box prev_box = boxes[prev_update_id];
                        internal_assert(current_box.size() == prev_box.size());
                        // If any of the boxes overlap, we need to throw an error
                        if (boxes_overlap(current_box, prev_box)) {
                            error();
                        }
                    }
                }

                if (is_target_func_or_buffer && call_op->name != func.name()) {
                    adjoints_used_by_others = true;
                }
            }
            expr_adjoints.clear();

            // Checking 2. here:
            bool all_zero_or_one_self_adjoint = true;
            for (int i = 0; i < (int)self_reference_adjoint.size(); i++) {
                if (!is_const(self_reference_adjoint[i], 0) &&
                    !is_const(self_reference_adjoint[i], 1)) {
                    all_zero_or_one_self_adjoint = false;
                    break;
                }
            }
            bool has_reduction_var = !func.rvars(update_id).empty();
            if (!all_zero_or_one_self_adjoint && has_reduction_var) {
                // a. is there any instance of reduction variable such that
                // the self reference update overwrites itself?
                // Or, equivalently, for all possible values of the reduction
                // variables, does the self reference update always
                // reads from/writes to different locations?
                // First we determine the ranges of RDoms for
                // and_condition_over_domain
                Scope<Interval> varying;
                // Loop over lhs & rhs to grab a reduction domain
                ReductionDomain r;
                const vector<Expr> &update_args = func.update_args(update_id);
                for (const Expr &expr : update_args) {
                    r = extract_rdom(expr);
                    if (r.defined()) {
                        break;
                    }
                }
                if (!r.defined()) {
                    for (int tuple_id = 0; tuple_id < (int)update_tuple.size();
                         tuple_id++) {
                        r = extract_rdom(update_tuple[tuple_id]);
                        if (r.defined()) {
                            break;
                        }
                    }
                }
                internal_assert(r.defined());
                // Go over all self reference call arguments
                bool is_not_overwriting = true;
                for (const vector<Expr> &self_ref_args : self_reference_args) {
                    internal_assert(self_ref_args.size() == update_args.size());
                    Expr not_overwriting_cond = const_false();
                    for (int arg_id = 0; arg_id < (int)self_ref_args.size(); arg_id++) {
                        // Are the read from/write to arguments always different?
                        not_overwriting_cond = simplify(not_overwriting_cond ||
                                                        (self_ref_args[arg_id] != update_args[arg_id]));
                    }
                    not_overwriting_cond = and_condition_over_domain(
                        not_overwriting_cond, varying);
                    // Needs to be true for all self reference
                    is_not_overwriting = is_not_overwriting &&
                                         can_prove(not_overwriting_cond);
                }

                // b. Even if the derivative is overwritten, as long as
                // we don't use it in this update we are good.
                // Otherwise we throw an error
                if (!is_not_overwriting && adjoints_used_by_others) {
                    error();
                }

                if (is_not_overwriting) {
                    // This is a non overwriting scan, let's remember it
                    non_overwriting_scans.insert(FuncKey{func.name(), update_id});
                }
            }
        }
    }
    is_forward_overwrite_detection_phase = false;

    // Bounds inference
    Box output_box;
    for (const auto &p : output_bounds) {
        // Convert from min,extent to min,max
        output_box.push_back(Interval(p.min, p.min + p.extent));
    }
    func_bounds = inference_bounds(output, output_box);
    for (const auto &it : func_bounds) {
        const Box &bounds = it.second;
        for (int d = 0; d < (int)bounds.size(); d++) {
            user_assert(bounds[d].is_bounded()) << "Access to function or buffer " << it.first << " at dimension " << d << " is not bounded. "
                                                << "We can only differentiate bounded accesses.\n";
        }
    }

    // Create a stub for each function and each update to accumulate adjoints.
    for (int func_id = 0; func_id < (int)funcs.size(); func_id++) {
        const Func &func = funcs[func_id];
        for (int update_id = -1; update_id < func.num_update_definitions(); update_id++) {
            Func adjoint_func(func.name() + "_" + std::to_string(update_id + 1) + "_d_def__");
            bool is_final_output = func_id == (int)funcs.size() - 1 &&
                                   update_id == func.num_update_definitions() - 1;
            vector<Var> args = func.args();
            for (auto &arg : args) {
                if (arg.is_implicit()) {
                    // Replace implicit variables with non implicit ones
                    arg = Var();
                }
            }
            if (is_final_output) {
                adjoint_func(args) = adjoint(args);
            } else {
                // Initialize to 0
                if (func.values().size() == 1) {
                    adjoint_func(args) = make_zero(func.values()[0].type());
                } else {
                    vector<Expr> init(func.values().size());
                    for (int i = 0; i < (int)init.size(); i++) {
                        init[i] = make_zero(func.values()[i].type());
                    }
                    adjoint_func(args) = Tuple(init);
                }
            }
            FuncKey func_key{func.name(), update_id};
            internal_assert(adjoint_funcs.find(func_key) == adjoint_funcs.end());
            adjoint_funcs[func_key] = adjoint_func;
        }
    }
    // Also create stubs for buffers referenced by the functions
    map<string, BufferInfo> called_buffers_or_param;
    for (int func_id = 0; func_id < (int)funcs.size(); func_id++) {
        const Func &func = funcs[func_id];
        map<string, BufferInfo> buffers = find_buffer_param_calls(func);
        called_buffers_or_param.insert(buffers.begin(), buffers.end());
    }
    for (const auto &it : called_buffers_or_param) {
        // Replace all the dots in the function names to make it legal.
        Func adjoint_func(replace_all(it.first, ".", "_") + "_d__");
        vector<Var> args(it.second.dimension);
        adjoint_func(args) = make_zero(it.second.type);
        FuncKey func_key{it.first, -1};
        if (adjoint_funcs.find(func_key) != adjoint_funcs.end()) {
            user_error << "Naming conflict between buffer/parameters and function:" << it.first << "\n";
        }
        adjoint_funcs[func_key] = adjoint_func;
    }

    // Traverse functions from producers to consumers for reverse accumulation
    for (int func_id = funcs.size() - 1; func_id >= 0; func_id--) {
        const Func &func = funcs[func_id];
        current_func = func;

        FuncKey func_key{func.name(), func.num_update_definitions() - 1};
        // Traverse from the last update to first
        for (int update_id = func.num_update_definitions() - 1;
             update_id >= -1; update_id--) {
            current_update_id = update_id;
            FuncKey func_key{func.name(), update_id};
            Func adjoint_func = adjoint_funcs[func_key];
            internal_assert(func_bounds.find(func.name()) != func_bounds.end());
            // The propagation of adjoints to self reference goes to
            // current update instead of previous if it's a non overwriting scan
            is_current_non_overwriting_scan = false;
            if (update_id >= 0) {
                auto it = non_overwriting_scans.find(func_key);
                if (it != non_overwriting_scans.end()) {
                    is_current_non_overwriting_scan = true;
                }
            }

            // Initialize the next adjoint function by
            // propagating the adjoints to next update
            // Example:
            // f(x) = ...
            // f(1) = ... <- we're here
            // We have an adjoint for f(1) defined over the whole support of f
            // Now we want to initialize for the f(x) update
            // Need to propagate back to all x while masking 1
            // x -> next_args
            // 1 -> update_args
            auto mask_previous_update = [&]() {
                FuncKey prev_func_key{func.name(), update_id - 1};
                Func &prev_adjoint_func = adjoint_funcs[prev_func_key];
                vector<Var> prev_args = prev_adjoint_func.args();
                vector<Expr> update_args = func.update_args(update_id);
                // Replace implicit variables
                for (auto &arg : update_args) {
                    set<string> implicit_variables =
                        find_implicit_variables(arg);
                    for (const auto &var : implicit_variables) {
                        arg = substitute(var, prev_args[Var::implicit_index(var)], arg);
                    }
                }
                // Check if prev_args are the same as update_args
                // If they are the same simply set everything to zero
                bool is_noop = true;
                for (int i = 0; i < (int)prev_args.size(); i++) {
                    const Variable *update_var = update_args[i].as<Variable>();
                    if (update_var == nullptr || prev_args[i].name() != update_var->name) {
                        is_noop = false;
                    }
                }
                prev_adjoint_func = Func(prev_adjoint_func.name());
                if (!is_noop) {
                    // f'(x) = adjoint
                    prev_adjoint_func(prev_args) =
                        adjoint_funcs[func_key](prev_args);
                    if (func.values().size() == 1) {
                        Type type = func.values()[0].type();
                        prev_adjoint_func(update_args) = make_zero(type);
                    } else {
                        vector<Expr> init(func.values().size());
                        for (int i = 0; i < (int)init.size(); i++) {
                            init[i] = make_zero(func.values()[i].type());
                        }
                        prev_adjoint_func(update_args) = Tuple(init);
                    }
                } else {
                    if (func.values().size() == 1) {
                        Type type = func.values()[0].type();
                        prev_adjoint_func(prev_args) = make_zero(type);
                    } else {
                        vector<Expr> init(func.values().size());
                        for (int i = 0; i < (int)init.size(); i++) {
                            init[i] = make_zero(func.values()[i].type());
                        }
                        prev_adjoint_func(prev_args) = Tuple(init);
                    }
                }
            };
            if (update_id >= 0 && !is_current_non_overwriting_scan) {
                // Delay the masking if we're keeping track of intermediate values.
                // Since in this case we are propagating to current update,
                // instead of previous update.
                mask_previous_update();
            }

            // Now we want to propagate the derivatives at expression level.
            // We topologically sort the expressions for each value in the tuple.
            vector<Expr> expr_list;
            Tuple rhs_tuple =
                update_id < 0 ? func.values() : func.update_values(update_id);
            vector<const BaseExprNode *> output_exprs;
            const vector<Expr> &rhs_tuple_vector = rhs_tuple.as_vector();
            for (const auto &expr : rhs_tuple_vector) {
                vector<Expr> value_expr_list = sort_expressions(expr);
                expr_list.insert(
                    expr_list.end(), value_expr_list.begin(), value_expr_list.end());
                output_exprs.push_back((const BaseExprNode *)expr_list.back().get());
            }

            // TODO: replace let_var_mapping with Scope
            // Gather let variables
            let_var_mapping.clear();
            let_variables.clear();
            for (auto it = expr_list.begin(); it != expr_list.end(); it++) {
                Expr expr = *it;
                if (expr.get()->node_type == IRNodeType::Let) {
                    const Let *op = expr.as<Let>();
                    // Assume Let variables are unique
                    internal_assert(let_var_mapping.find(op->name) == let_var_mapping.end());
                    let_var_mapping[op->name] = op->value;
                    let_variables.push_back(op->name);
                }
            }

            // Retrieve previously propagated adjoint for the Func,
            // apply it to expression adjoints.
            // f(x) = g(x)
            // d_g(x) = d_f(x) * df/dg
            vector<Expr> update_args;
            if (update_id >= 0) {
                update_args = func.update_args(update_id);
            } else {
                update_args.reserve(func.args().size());
                Func adjoint_func = adjoint_funcs[func_key];
                for (const auto &var : adjoint_func.args()) {
                    update_args.push_back(var);
                }
            }

            // We propagate in two phases, the first phase only propagates
            // to self references, the second phase propagates to the rest.
            {  // First phase
                is_self_referencing_phase = true;
                expr_adjoints.clear();
                if (output_exprs.size() == 1) {
                    expr_adjoints[output_exprs[0]] =
                        (adjoint_funcs[func_key])(update_args);
                } else {
                    for (int i = 0; i < (int)output_exprs.size(); i++) {
                        expr_adjoints[output_exprs[i]] =
                            (adjoint_funcs[func_key])(update_args)[i];
                    }
                }

                // Traverse the expressions in reverse order
                for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) {
                    if (it->type().is_handle()) {
                        // Ignore pointer types
                        continue;
                    }
                    // Propagate adjoints
                    it->accept(this);
                }
            }
            if (is_current_non_overwriting_scan) {
                // Now, if we detect a non-overwriting scan operation,
                // the update of adjoints goes to the current function.
                // We let the previous adjoint the same as the current one

                FuncKey prev_func_key{func_key.first, func_key.second - 1};
                // Recreate a new adjoint for previous update
                Func prev_adjoint;
                vector<Expr> args;
                args.reserve(adjoint_func.args().size());
                for (const auto &arg : adjoint_func.args()) {
                    args.push_back(arg);
                }
                vector<Expr> calls;
                calls.reserve(rhs_tuple.size());
                for (int i = 0; i < (int)rhs_tuple.size(); i++) {
                    calls.push_back(Call::make(
                        adjoint_funcs[func_key].function(), args, i));
                }
                prev_adjoint(args) = Tuple(calls);
                adjoint_funcs[prev_func_key] = prev_adjoint;
                mask_previous_update();
            }
            {  // Second phase
                is_self_referencing_phase = false;
                expr_adjoints.clear();
                for (int i = 0; i < (int)output_exprs.size(); i++) {
                    expr_adjoints[output_exprs[i]] =
                        Call::make(adjoint_funcs[func_key].function(),
                                   update_args, i);
                }

                int count = 0;
                // Traverse the expressions in reverse order
                for (auto it = expr_list.rbegin(); it != expr_list.rend(); it++) {
                    if (it->type().is_handle()) {
                        // Ignore pointer types
                        continue;
                    }
                    // Propagate adjoints
                    it->accept(this);
                    count++;
                }
            }
        }
    }
}

void ReverseAccumulationVisitor::accumulate(const Expr &stub, Expr adjoint) {
    const BaseExprNode *stub_ptr = (const BaseExprNode *)stub.get();

    // Trick to avoid NaN in select() clauses:
    // select(c, x, 0) * y -> select(c, x * y, 0)
    // x * select(c, y, 0) -> select(c, x * y, 0)
    // select(c, x, 0) / y -> select(c, x / y, 0)
    if (adjoint.as<Mul>() != nullptr) {
        const Mul *mul_op = adjoint.as<Mul>();
        auto mul_select_with_zero = [&](const Expr &sel, const Expr &other) {
            const Select *sel_op = sel.as<Select>();
            if (is_zero(sel_op->true_value)) {
                return select(sel_op->condition,
                              sel_op->true_value, sel_op->false_value * other);
            }
            if (is_zero(sel_op->false_value)) {
                return select(sel_op->condition,
                              sel_op->true_value * other, sel_op->false_value);
            }
            return sel * other;
        };
        if (mul_op->a.as<Select>() != nullptr) {
            adjoint = mul_select_with_zero(mul_op->a, mul_op->b);
        } else if (mul_op->b.as<Select>() != nullptr) {
            adjoint = mul_select_with_zero(mul_op->b, mul_op->a);
        }
    }
    if (adjoint.as<Div>() != nullptr) {
        const Div *div_op = adjoint.as<Div>();
        auto div_select_with_zero = [&](const Expr &sel, const Expr &other) {
            const Select *sel_op = sel.as<Select>();
            if (is_zero(sel_op->true_value)) {
                return select(sel_op->condition,
                              sel_op->true_value, sel_op->false_value / other);
            }
            if (is_zero(sel_op->false_value)) {
                return select(sel_op->condition,
                              sel_op->true_value / other, sel_op->false_value);
            }
            return sel * other;
        };
        if (div_op->a.as<Select>() != nullptr) {
            adjoint = div_select_with_zero(div_op->a, div_op->b);
        }
    }

    if (expr_adjoints.find(stub_ptr) == expr_adjoints.end()) {
        expr_adjoints[stub_ptr] = adjoint;
    } else {
        expr_adjoints[stub_ptr] = expr_adjoints[stub_ptr] + adjoint;
    }
}

void ReverseAccumulationVisitor::visit(const IntImm *) {
    // Integer constant: a leaf with no operands, so no adjoint flows anywhere.
}

void ReverseAccumulationVisitor::visit(const UIntImm *) {
    // Unsigned constant: a leaf with no operands, so no adjoint flows anywhere.
}

void ReverseAccumulationVisitor::visit(const FloatImm *) {
    // Float constant: a leaf with no operands, so no adjoint flows anywhere.
}

void ReverseAccumulationVisitor::visit(const StringImm *) {
    // String constant: a leaf with no operands, so no adjoint flows anywhere.
}

void ReverseAccumulationVisitor::visit(const Cast *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/dx cast(x) = 1.f if op->type is float otherwise 0
    if (op->type.is_float()) {
        accumulate(op->value, cast(op->value.type(), adjoint));
    } else {
        accumulate(op->value, make_zero(op->value.type()));
    }
}

void ReverseAccumulationVisitor::visit(const Variable *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    if (op->param.defined()) {
        // This is a reference to a Parameter, propagate to the corresponding buffer
        propagate_halide_function_call(adjoint, op->param.name(), FunctionPtr(), {}, 0, op->type);
        return;
    }

    // If the variable is a let variable, accumulates adjoints into the content
    auto it = let_var_mapping.find(op->name);
    if (it != let_var_mapping.end()) {
        accumulate(it->second, Let::make(op->name, it->second, adjoint));
    }
}

void ReverseAccumulationVisitor::visit(const Add *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/da a + b = 1
    accumulate(op->a, adjoint);
    // d/db a + b = 1
    accumulate(op->b, adjoint);
}

void ReverseAccumulationVisitor::visit(const Sub *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/da a - b = 1
    accumulate(op->a, adjoint);
    // d/db a - b = -1
    accumulate(op->b, -adjoint);
}

void ReverseAccumulationVisitor::visit(const Mul *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/da a * b = b
    accumulate(op->a, adjoint * op->b);
    // d/db a * b = a
    accumulate(op->b, adjoint * op->a);
}

void ReverseAccumulationVisitor::visit(const Div *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // Trick to avoid NaN in select() clauses: if adjoint is a select with an 0,
    // multiply into it
    if (adjoint.as<Select>() != nullptr) {
        const Select *sel_op = adjoint.as<Select>();
        if (is_zero(sel_op->true_value)) {
            // d/da a / b = 1 / b
            accumulate(op->a, select(sel_op->condition,
                                     sel_op->true_value, sel_op->false_value / op->b));
            // d/db a * b = - a / b^2
            accumulate(op->b, select(sel_op->condition,
                                     sel_op->true_value, -sel_op->false_value * op->a / (op->b * op->b)));
            return;
        }
        if (is_zero(sel_op->false_value)) {
            // d/da a / b = 1 / b
            accumulate(op->a, select(sel_op->condition,
                                     sel_op->true_value / op->b, sel_op->false_value));
            // d/db a * b = - a / b^2
            accumulate(op->b, select(sel_op->condition,
                                     -sel_op->true_value * op->a / (op->b * op->b), sel_op->false_value));
            return;
        }
    }

    // d/da a / b = 1 / b
    accumulate(op->a, adjoint / op->b);
    // d/db a / b = - a / b^2
    accumulate(op->b, -adjoint * op->a / (op->b * op->b));
}

void ReverseAccumulationVisitor::visit(const Mod *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // a % b = a - trunc(a/b) * b
    // d/da = 1
    accumulate(op->a, adjoint);
    // d/db = -trunc(a/b)
    accumulate(op->b, -adjoint * trunc(op->a / op->b));
}

void ReverseAccumulationVisitor::visit(const Min *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/da min(a, b) = a <= b ? 1 : 0
    accumulate(op->a,
               select(op->a <= op->b, adjoint, make_zero(adjoint.type())));
    // d/db min(a, b) = b <= a ? 1 : 0
    accumulate(op->b,
               select(op->b <= op->a, adjoint, make_zero(adjoint.type())));
}

void ReverseAccumulationVisitor::visit(const Max *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/da max(a, b) = a >= b ? 1 : 0
    accumulate(op->a,
               select(op->a >= op->b, adjoint, make_zero(adjoint.type())));
    // d/db max(a, b) = b >= a ? 1 : 0
    accumulate(op->b,
               select(op->b >= op->a, adjoint, make_zero(adjoint.type())));
}

void ReverseAccumulationVisitor::visit(const EQ *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const NE *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const LT *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const LE *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const GT *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const GE *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const And *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const Or *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the arguments
    accumulate(op->a, make_zero(op->a.type()));
    accumulate(op->b, make_zero(op->b.type()));
}

void ReverseAccumulationVisitor::visit(const Not *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    // Expr adjoint = expr_adjoints[op];

    // output is a boolean, so we should propagate zero to the argument
    accumulate(op->a, make_zero(op->a.type()));
}

void ReverseAccumulationVisitor::visit(const Let *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    accumulate(op->body, adjoint);
}

void ReverseAccumulationVisitor::visit(const Select *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];

    // d/db select(a, b, c) = select(a, 1, 0)
    accumulate(op->true_value,
               select(op->condition, adjoint, make_zero(adjoint.type())));
    // d/dc select(a, b, c) = select(a, 0, 1)
    accumulate(op->false_value,
               select(op->condition, make_zero(adjoint.type()), adjoint));
}

// Propagates the adjoint of a Call node to its arguments.
//  - Extern math functions (matched via is_float_extern, i.e. the
//    "_f16"/"_f32"/"_f64" variants) use their closed-form derivatives.
//  - Recognized intrinsics are handled case by case. Unrecognized intrinsics
//    emit a user warning and propagate zero.
//  - Halide function and image calls are forwarded to
//    propagate_halide_function_call.
//  - Unknown externs and unknown call types raise an internal error.
void ReverseAccumulationVisitor::visit(const Call *op) {
    internal_assert(expr_adjoints.find(op) != expr_adjoints.end());
    Expr adjoint = expr_adjoints[op];
    if (op->is_extern()) {
        // Math functions
        if (is_float_extern(op->name, "exp")) {
            // d/dx exp(x) = exp(x)
            accumulate(op->args[0], adjoint * exp(op->args[0]));
        } else if (is_float_extern(op->name, "log")) {
            // d/dx log(x) = 1 / x
            accumulate(op->args[0], adjoint / op->args[0]);
        } else if (is_float_extern(op->name, "sin")) {
            // d/dx sin(x) = cos(x)
            accumulate(op->args[0], adjoint * cos(op->args[0]));
        } else if (is_float_extern(op->name, "asin")) {
            // d/dx asin(x) = 1 / sqrt(1 - x^2)
            Expr one = make_one(op->type);
            accumulate(op->args[0], adjoint / sqrt(one - op->args[0] * op->args[0]));
        } else if (is_float_extern(op->name, "cos")) {
            // d/dx cos(x) = -sin(x)
            accumulate(op->args[0], -adjoint * sin(op->args[0]));
        } else if (is_float_extern(op->name, "acos")) {
            // d/dx acos(x) = - 1 / sqrt(1 - x^2)
            Expr one = make_one(op->type);
            accumulate(op->args[0], -adjoint / sqrt(one - op->args[0] * op->args[0]));
        } else if (is_float_extern(op->name, "tan")) {
            // d/dx tan(x) = 1 / cos(x)^2
            Expr c = cos(op->args[0]);
            accumulate(op->args[0], adjoint / (c * c));
        } else if (is_float_extern(op->name, "atan")) {
            // d/dx atan(x) = 1 / (1 + x^2)
            Expr one = make_one(op->type);
            accumulate(op->args[0], adjoint / (one + op->args[0] * op->args[0]));
        } else if (is_float_extern(op->name, "atan2")) {
            // args[0] is y, args[1] is x.
            Expr x2y2 = op->args[0] * op->args[0] + op->args[1] * op->args[1];
            // d/dy atan2(y, x) = x / (x^2 + y^2)
            accumulate(op->args[0], adjoint * (op->args[1] / x2y2));
            // d/dx atan2(y, x) = -y / (x^2 + y^2)
            accumulate(op->args[1], adjoint * (-op->args[0] / x2y2));
        } else if (is_float_extern(op->name, "sinh")) {
            // d/dx sinh(x) = cosh(x)
            accumulate(op->args[0], adjoint * cosh(op->args[0]));
        } else if (is_float_extern(op->name, "asinh")) {
            // d/dx asinh(x) = 1 / sqrt(1 + x^2)
            Expr one = make_one(op->type);
            accumulate(op->args[0], adjoint / sqrt(one + op->args[0] * op->args[0]));
        } else if (is_float_extern(op->name, "cosh")) {
            // d/dx cosh(x) = sinh(x)
            accumulate(op->args[0], adjoint * sinh(op->args[0]));
        } else if (is_float_extern(op->name, "acosh")) {
            // d/dx acosh(x) = 1 / (sqrt(x - 1) * sqrt(x + 1))
            Expr one = make_one(op->type);
            accumulate(op->args[0],
                       adjoint / (sqrt(op->args[0] - one) * sqrt(op->args[0] + one)));
        } else if (is_float_extern(op->name, "tanh")) {
            // d/dx tanh(x) = 1 / cosh(x)^2
            Expr c = cosh(op->args[0]);
            accumulate(op->args[0], adjoint / (c * c));
        } else if (is_float_extern(op->name, "atanh")) {
            // d/dx atanh(x) = 1 / (1 - x^2)
            Expr one = make_one(op->type);
            accumulate(op->args[0], adjoint / (one - op->args[0] * op->args[0]));
        } else if (is_float_extern(op->name, "ceil")) {
            // TODO: d/dx = dirac(n) for n in Z ...
            accumulate(op->args[0], make_zero(op->type));
        } else if (is_float_extern(op->name, "floor")) {
            // TODO: d/dx = dirac(n) for n in Z ...
            accumulate(op->args[0], make_zero(op->type));
        } else if (is_float_extern(op->name, "round")) {
            // Piecewise constant: zero derivative almost everywhere.
            accumulate(op->args[0], make_zero(op->type));
        } else if (is_float_extern(op->name, "trunc")) {
            // Piecewise constant: zero derivative almost everywhere.
            accumulate(op->args[0], make_zero(op->type));
        } else if (is_float_extern(op->name, "sqrt")) {
            // d/dx sqrt(x) = 0.5 / sqrt(x)
            Expr half = make_const(op->type, 0.5);
            accumulate(op->args[0], adjoint * (half / sqrt(op->args[0])));
        } else if (is_float_extern(op->name, "pow")) {
            // d/dx pow(x, y) = y * pow(x, y - 1)
            Expr one = make_one(op->type);
            accumulate(op->args[0],
                       adjoint * op->args[1] * pow(op->args[0], op->args[1] - one));
            // d/dy pow(x, y) = pow(x, y) * log(x)
            // (log(x) is not finite for x <= 0.)
            accumulate(op->args[1],
                       adjoint * pow(op->args[0], op->args[1]) * log(op->args[0]));
        } else if (is_float_extern(op->name, "fast_inverse")) {
            // d/dx 1/x = -1/x^2
            Expr inv_x = fast_inverse(op->args[0]);
            accumulate(op->args[0], -adjoint * inv_x * inv_x);
        } else if (is_float_extern(op->name, "fast_inverse_sqrt")) {
            // d/dx x^(-0.5) = -0.5*x^(-1.5)
            Expr inv_sqrt_x = fast_inverse_sqrt(op->args[0]);
            Expr neg_half = make_const(op->type, -0.5);
            accumulate(op->args[0],
                       neg_half * adjoint * inv_sqrt_x * inv_sqrt_x * inv_sqrt_x);
        } else if (op->name == "halide_print") {
            // Printing contributes no gradient to any of its arguments.
            for (const auto &arg : op->args) {
                accumulate(arg, make_zero(op->type));
            }
        } else {
            internal_error << "The derivative of " << op->name << " is not implemented.";
        }
    } else if (op->is_intrinsic()) {
        if (op->is_intrinsic(Call::abs)) {
            // d/dx abs(x) = sign(x); this uses -1 at x == 0 (x > 0 is false).
            accumulate(op->args[0],
                       adjoint * select(op->args[0] > 0,
                                        make_one(op->type), make_const(op->type, -1.0)));
        } else if (op->is_intrinsic(Call::lerp)) {
            // z = x * (1 - w) + y * w
            // dz/dx = 1 - w
            // dz/dy = w
            // dz/dw = y - x
            accumulate(op->args[0], adjoint * (make_one(op->type) - op->args[2]));
            accumulate(op->args[1], adjoint * op->args[2]);
            accumulate(op->args[2], adjoint * (op->args[1] - op->args[0]));
        } else if (op->is_intrinsic(Call::likely)) {
            // likely(x) is x with a scheduling hint: pass the adjoint through.
            accumulate(op->args[0], adjoint);
        } else if (op->is_intrinsic(Call::return_second)) {
            // Only the second argument determines the value.
            accumulate(op->args[0], make_const(op->type, 0.0));
            accumulate(op->args[1], adjoint);
        } else if (op->is_intrinsic(Call::undef)) {
            // do nothing
        } else if (op->is_intrinsic(Call::reinterpret) ||
                   op->is_intrinsic(Call::bitwise_and) ||
                   op->is_intrinsic(Call::bitwise_not) ||
                   op->is_intrinsic(Call::bitwise_or) ||
                   op->is_intrinsic(Call::bitwise_xor) ||
                   op->is_intrinsic(Call::shift_right) ||
                   op->is_intrinsic(Call::shift_left)) {
            // bit manipulations -- these have zero derivatives.
            for (const auto &arg : op->args) {
                accumulate(arg, make_zero(op->type));
            }
        } else {
            // Unhandled intrinsic: warn, and treat it as having zero derivative.
            user_warning << "Dropping gradients at call to " << op->name << "\n";
            for (const auto &arg : op->args) {
                accumulate(arg, make_zero(op->type));
            }
        }
    } else if (op->call_type == Call::Halide ||
               op->call_type == Call::Image) {  // Halide function call or Halide buffer
        propagate_halide_function_call(adjoint, op->name, op->func, op->args, op->value_index, op->type);
    } else {
        // TODO: let user provide derivatives for external functions
        internal_error << "Unknown call type of operation: " << op->name << "\n";
    }
}

void ReverseAccumulationVisitor::propagate_halide_function_call(
    Expr adjoint, const std::string &name, const FunctionPtr &func_ptr,
    const std::vector<Expr> &call_args, int value_index, const Type &type) {
    if (!type.is_float()) {
        // If the function call does not return continuous output,
        // don't propagate to the function.
        return;
    }
    // Add Let expressions
    adjoint = add_let_expression(adjoint, let_var_mapping, let_variables);
    vector<Expr> lhs = call_args;
    for (int i = 0; i < (int)lhs.size(); i++) {
        lhs[i] = add_let_expression(lhs[i], let_var_mapping, let_variables);
    }
    Expr adjoint_before_canonicalize = adjoint;
    vector<Expr> lhs_before_canonicalize = lhs;

    if (is_forward_overwrite_detection_phase) {
        // Don't need to propagate through function in this phase, we're just
        // checking local derivatives
        // However, we'll accumulate the derivatives to self reference
        // for checking if the self update is harmful for gradients
        if (func_ptr.same_as(current_func.function().get_contents())) {
            self_reference_adjoint[value_index] =
                simplify(self_reference_adjoint[value_index] + adjoint);
            vector<Expr> args = call_args;
            for (int i = 0; i < (int)args.size(); i++) {
                args[i] = add_let_expression(args[i], let_var_mapping, let_variables);
            }
            self_reference_args.push_back(args);
        }
        return;
    }
    if (is_self_referencing_phase) {
        // We want to make sure we propagate to the self reference first.
        // In this phase only self reference is propagated
        if (!func_ptr.same_as(current_func.function().get_contents())) {
            return;
        }
    } else {
        // In the other phase we ignore the self reference
        if (func_ptr.same_as(current_func.function().get_contents())) {
            return;
        }
    }

    // We create different functions for the initial condition and each update
    // When update i uses value from update i-1, we accumulate the
    // adjoints to update i-1
    // If target is the current function itself, send to previous update
    // e.g. f(x) = ...
    //      f(x) = f(x) + 1
    // For the one with non-commutative-associative reductions
    // e.g. f(x, ver) = ...
    //      f(x, 0) = ...
    //      f(x, r.x + 1) = f(x, r.x) * f(x, r.x) + g(r.x)
    // We propagate the whole r.x to the current update.
    // In addition, we propagate the first one d_f(x, 0) to the previous update,
    // by setting all reduction variables to their min() values.
    // Because only f(x, 0) comes from the last update, and
    // the rest belongs to the current update.
    // The above case will be handled by the caller, here we just
    // propagate to current update.
    // TODO: make the comments clearer and clean up the code
    FuncKey func_key;
    if (func_ptr.defined()) {
        Function func(func_ptr);
        func_key = func.name() != current_func.name() ? FuncKey{func.name(), func.updates().size() - 1} : FuncKey{func.name(), current_update_id - 1};
        if (is_current_non_overwriting_scan && is_self_referencing_phase) {
            func_key = FuncKey{func.name(), current_update_id};
        }
    } else {
        func_key = FuncKey{name, -1};
    }
    internal_assert(adjoint_funcs.find(func_key) != adjoint_funcs.end());
    Func &func_to_update = adjoint_funcs[func_key];
    internal_assert(func_to_update.dimensions() == (int)lhs.size());

    bool debug_flag = false;
    adjoint = simplify(common_subexpression_elimination(adjoint));

    if (debug_flag) {
        debug(0) << "current_func:" << current_func.name() << "\n";
        debug(0) << "Scattering to " << name << "\n";
        debug(0) << "lhs is:";
        for (const auto &arg : lhs) {
            debug(0) << " " << arg;
        }
        debug(0) << "\n";
        debug(0) << "adjoint is:" << simplify(adjoint) << "\n";
    }

    // Gather argument & bounds information
    // current_args are the pure variables
    // current_update_args are the actual updates at left hand side
    Func current_adjoint_func =
        adjoint_funcs[FuncKey{current_func.name(), current_update_id}];
    vector<Var> current_args = current_adjoint_func.args();
    const Box &current_bounds = func_bounds[current_func.name()];

    // Replace implicit variables
    for (auto &arg : lhs) {
        set<string> implicit_variables = find_implicit_variables(arg);
        for (const auto &var : implicit_variables) {
            arg = substitute(var, current_args[Var::implicit_index(var)], arg);
        }
    }
    {
        set<string> implicit_variables =
            find_implicit_variables(adjoint);
        for (const auto &var : implicit_variables) {
            adjoint = substitute(
                var, current_args[Var::implicit_index(var)], adjoint);
        }
    }

    // We want to do this:
    // func_to_update(call_args) += adjoint(current_update_args);
    // But call_args can be invalid lhs, need to canonicalize.
    // We canonicalize by first trying to substitute with pure variables.
    // If that fails we will replace variables on lhs with RDoms
    // (general scattering).

    // We try canonicalize the left hand side arguments (call_args)
    // so that it's always x, y, z, ...
    //
    // Given:
    // g(x, y, z) = f(x, y-1, z+1)
    // we get an invalid update:
    // f'(x, y - 1, z + 1) += g'(x, y, z)
    // Goal: rewrite to
    //  ==> f'(x, y, z) += g'(x, y+1, z-1)
    // (below we would call g and g' the "current function" and
    //  we call f and d_f the "function to update")
    //
    // We do this by set up a new set of variables new_args
    // new_args contains a set of variable u0, u1, u2, ...
    // For each left hand side of the update (x, y - 1, z + 1 here),
    // we set up the equations u0 = x, u1 = y - 1, u2 = z + 1.
    // Then we solve for x, y, z and get x = u0, y = u1 + 1, z = u2 - 1
    // We get f'(u0, u1, u2) += g'(u0, u1 + 1, u2 - 1)
    // We then substitute the original variable names back to get
    // f'(x, y, z) += g'(x, x + 1, z - 1)
    //
    // Note that g' would correctly returns 0 outside g's boundary,
    // therefore we do not need to impose bounds on g'.
    // However, consider the case where f'(...) += g'(...) * h(...):
    // we need to clamp h's arguments such that it never goes out of g's domain,
    // otherwise we may get unwanted out-of-bound buffer access.
    //
    // Currently we don't want to mess with system solving.
    // Therefore we gather all arguments that contains multiple pure variables,
    // and invalidate all of them.
    // Inter-dependencies like:
    // g(x, y) = f(x * y, x + y)
    // can't be simplified.
    // In principle this can be inverted by solving a system of equations.
    // In this case we replace x and y with reduction variables that loop
    // through g's bounds
    // i.e.
    // f'(r.x * r.y, r.x + r.y) += g'(r.x, r.y)

    // Prepare a set of new substitution variables for func_to_update
    vector<Var> new_args;
    new_args.reserve(func_to_update.dimensions());
    for (int arg_id = 0; arg_id < (int)func_to_update.dimensions(); arg_id++) {
        new_args.emplace_back(unique_name("u" + std::to_string(arg_id)));
    }

    // Loop over the left hand side of the update, construct equations
    // and invert them.
    vector<bool> canonicalized(lhs.size(), false);
    set<string> canonicalized_vars;
    map<string, Var> lhs_substitute_map;
    for (int arg_id = 0; arg_id < (int)lhs.size(); arg_id++) {
        // Gather all pure variables at call_args[arg_id],
        // substitute them with new_args
        // For now only support single pure variable
        vector<int> variable_ids =
            gather_variables(lhs[arg_id], current_args);
        if (variable_ids.size() != 1) {
            continue;
        }

        int variable_id = variable_ids[0];
        const string &variable = current_args[variable_id].name();
        bool solved;
        Expr result_rhs;
        std::tie(solved, result_rhs) =
            solve_inverse(new_args[arg_id] == lhs[arg_id],
                          new_args[arg_id].name(),
                          variable);
        if (!solved) {
            continue;
        }

        // Substitute all access to variable to clamped version
        Expr clamped_variable = clamp(likely(current_args[variable_id]),
                                      current_bounds[variable_id].min,
                                      current_bounds[variable_id].max);
        adjoint = substitute_rdom_predicate(variable, clamped_variable, adjoint);
        // However we don't want to clamp the access to adjoint function (we need
        // it to return 0 outside of its bounds). We replace the corresponding
        // clamped argument back with the pure variable. It is safe to do
        // so because pure variable in Halide's update function can only be appeared
        // unadorned in the same position.
        adjoint = substitute_call_arg_with_pure_arg(current_adjoint_func,
                                                    variable_id,
                                                    adjoint);

        // Replace pure variable with the reverse.
        // Make sure to also substitute predicates.
        adjoint = substitute_rdom_predicate(variable, result_rhs, adjoint);

        // Since we successfully invert, the left hand side becomes new_args
        lhs[arg_id] = new_args[arg_id];
        // Record that we successfully invert, for those we fail
        // we need to perform general scattering.
        canonicalized[arg_id] = true;
        canonicalized_vars.insert(variable);
        lhs_substitute_map[variable] = new_args[arg_id];
    }

    // Consider the following case:
    // f(x, y) = ...
    // k(n) = f(g(n), n)
    // When we update d_f, the second n would be replaced by y.
    // We need to make sure we also update the call argument to g.
    // Adjoint is automatically handled in the loop above.
    for (int i = 0; i < (int)lhs.size(); i++) {
        for (const auto &it : lhs_substitute_map) {
            lhs[i] = substitute(it.first, it.second, lhs[i]);
        }
    }

    // Sometimes the canonicalization above fails.
    // We replace the pure variables inside lhs with RDoms for general scattering
    Region bounds;
    bounds.reserve(current_args.size());
    for (int arg_id = 0; arg_id < (int)current_args.size(); arg_id++) {
        const Interval &interval = current_bounds[arg_id];
        bounds.emplace_back(interval.min, interval.max - interval.min + 1);
    }
    RDom r_bounds(bounds);
    for (int lhs_id = 0; lhs_id < (int)lhs.size(); lhs_id++) {
        if (!canonicalized[lhs_id]) {
            Expr lhs_arg = lhs[lhs_id];
            vector<string> adjoint_args = current_adjoint_func.function().args();
            vector<int> variable_ids = gather_variables(lhs_arg, adjoint_args);
            // For each variable found in lhs_arg, find the corresponding
            // bound (by looping through all variables) and substitute
            // with the bound reduction variable.
            for (int var_id = 0; var_id < (int)variable_ids.size(); var_id++) {
                for (int arg_id = 0; arg_id < (int)current_args.size(); arg_id++) {
                    const string &variable = adjoint_args[variable_ids[var_id]];
                    if (current_args[arg_id].name() == variable &&
                        canonicalized_vars.find(
                            current_args[arg_id].name()) ==
                            canonicalized_vars.end()) {
                        lhs[lhs_id] = substitute(variable,
                                                 r_bounds[arg_id],
                                                 lhs[lhs_id]);
                        adjoint = substitute(variable, r_bounds[arg_id], adjoint);
                        break;
                    }
                }
            }
        }
    }

    // For each free variable on the rhs, replace it with current bounds
    // e.g. we have in forward pass f(x, y) = g(x)
    //      then we would have g'(x) += f'(x, y) by now
    //      now we need to replace y with a reduction variable over f's bound
    //      x is automatically excluded since it's currently
    //      replaced by the new substitution variable e.g. u_0

    // First gather all free variables
    Region bounds_subset;
    vector<int> arg_id_to_substitute;
    bounds_subset.reserve(current_args.size());
    arg_id_to_substitute.reserve(current_args.size());
    for (int arg_id = 0; arg_id < (int)current_args.size(); arg_id++) {
        if (expr_uses_var(adjoint, current_args[arg_id].name())) {
            const Interval &interval = current_bounds[arg_id];
            bounds_subset.emplace_back(
                interval.min, interval.max - interval.min + 1);
            arg_id_to_substitute.push_back(arg_id);
        }
    }

    // Create a new RDom to loop over all free variables
    if (!arg_id_to_substitute.empty()) {
        RDom r(bounds_subset);
        for (int i = 0; i < (int)arg_id_to_substitute.size(); i++) {
            int arg_id = arg_id_to_substitute[i];
            adjoint = substitute(current_args[arg_id].name(), r[i], adjoint);
        }
    }

    // Simplify expressions
    adjoint = simplify(common_subexpression_elimination(adjoint));
    for (int i = 0; i < (int)lhs.size(); i++) {
        lhs[i] = simplify(common_subexpression_elimination(lhs[i]));
    }

    vector<Var> func_to_update_args = func_to_update.args();

    // General scattering simplification rules:
    // For each expression in lhs,
    // check if it is an expression of a single (associative & commutative)
    // rvar and spans the same interval of the function's bound
    // if so we can rewrite it back to pure variables
    // e.g.
    // f(r.x) = g(r.x)
    // => f(x) = g(x)
    //
    // Another common pattern is the reverse of downsampling
    // if we see s * r.x + r.y and r.y has min == 0 and extent == s
    // we simplify them to x and replace all occurrences of r.x by x/4
    // e.g.
    // f(4 * r.x + r.y) = g(r.x) + h(4 * r.x + r.y)
    // => f(x) = g(x/4) + h(x)
    Expr new_adjoint = func_to_update.values().size() == 1 ? (func_to_update(lhs) + adjoint) : (func_to_update(lhs)[value_index] + adjoint);
    vector<Expr> new_adjoint_tuple(func_to_update.values().size(), Expr(0.f));
    new_adjoint_tuple[value_index] = new_adjoint;
    AssociativeOp associative_op = prove_associativity(
        func_to_update.name(), lhs, new_adjoint_tuple);
    if (associative_op.associative() && associative_op.commutative()) {
        for (int i = 0; i < (int)lhs.size(); i++) {
            Expr lhs_arg = substitute_in_all_lets(lhs[i]);
            const Variable *var = lhs_arg.as<Variable>();
            const Add *add = lhs_arg.as<Add>();
            // f(r.x) = ... && r is associative
            // => f(x) = ...
            if (var != nullptr && var->reduction_domain.defined() &&
                var->reduction_domain.split_predicate().empty()) {
                ReductionDomain rdom = var->reduction_domain;
                int rvar_id = -1;
                for (int rid = 0; rid < (int)rdom.domain().size(); rid++) {
                    if (rdom.domain()[rid].var == var->name) {
                        rvar_id = rid;
                        break;
                    }
                }
                internal_assert(rvar_id != -1);
                ReductionVariable rvar = rdom.domain()[rvar_id];
                // Check if the min/max of the rvariable is equal to
                // the target function
                const Box &target_bounds = func_bounds[name];
                Interval t_interval = target_bounds[i];
                t_interval.min = simplify(t_interval.min);
                t_interval.max = simplify(t_interval.max);
                Interval r_interval(simplify(rvar.min),
                                    simplify(rvar.min + rvar.extent - 1));
                if (can_prove(r_interval.min <= t_interval.min &&
                              r_interval.max >= t_interval.max)) {
                    lhs[i] = func_to_update_args[i];
                    Expr clamped_arg = clamp(func_to_update_args[i],
                                             r_interval.min, r_interval.max);
                    // Replace other occurrence of rvar in lhs
                    for (int j = 0; j < (int)lhs.size(); j++) {
                        if (j != i) {
                            lhs[j] = simplify(substitute(
                                rvar.var, clamped_arg, lhs[j]));
                        }
                    }
                    // Take care of boundary condition
                    Expr in_bound = func_to_update_args[i] >= r_interval.min &&
                                    func_to_update_args[i] <= r_interval.max;
                    adjoint = select(in_bound,
                                     simplify(substitute(rvar.var, clamped_arg, adjoint)),
                                     make_zero(adjoint.type()));
                }
                // f(4 * r.x + r.y) = g(r.x) + h(4 * r.x + r.y)
                // => f(x) = g(x/4) + h(x)
            } else if (add != nullptr &&
                       ((add->a.as<Mul>() != nullptr &&
                         add->b.as<Variable>() != nullptr) ||
                        (add->a.as<Variable>() != nullptr &&
                         add->b.as<Mul>() != nullptr))) {
                // Find pattern s * r.x + r.y where r.y.min == 0 && r.y.extent == s
                Expr a = add->a, b = add->b;
                if (add->b.as<Mul>() != nullptr) {
                    // swap so that b is always the Variable
                    internal_assert(add->a.as<Variable>() != nullptr);
                    std::swap(a, b);
                }
                const Mul *mul = a.as<Mul>();
                const Variable *b_var = b.as<Variable>();
                internal_assert(mul != nullptr && b_var != nullptr);
                Expr mul_a = mul->a, mul_b = mul->b;
                if (mul_a.as<Variable>() != nullptr &&
                    mul_a.as<Variable>()->reduction_domain.defined()) {
                    std::swap(mul_a, mul_b);
                }
                const Variable *mul_b_var = mul_b.as<Variable>();
                if (mul_b_var == nullptr || !mul_b_var->reduction_domain.defined()) {
                    continue;
                }
                ReductionDomain b_rdom = b_var->reduction_domain;
                if (!b_rdom.defined()) {
                    continue;
                }

                int rvar_id = -1;
                for (int rid = 0; rid < (int)b_rdom.domain().size(); rid++) {
                    if (b_rdom.domain()[rid].var == b_var->name) {
                        rvar_id = rid;
                        break;
                    }
                }
                internal_assert(rvar_id != -1);
                ReductionVariable rvar = b_rdom.domain()[rvar_id];
                if (!equal(rvar.min, Expr(0)) || !equal(rvar.extent, mul_a)) {
                    continue;
                }

                ReductionDomain mul_b_rdom = mul_b_var->reduction_domain;
                int mulb_rvar_id = -1;
                for (int rid = 0; rid < (int)mul_b_rdom.domain().size(); rid++) {
                    if (mul_b_rdom.domain()[rid].var == mul_b_var->name) {
                        mulb_rvar_id = rid;
                        break;
                    }
                }
                internal_assert(mulb_rvar_id != -1);
                ReductionVariable mulb_rvar = b_rdom.domain()[mulb_rvar_id];

                // Check if the min/max of the s * r.x + r.y is equal to
                // the target function
                const Box &target_bounds = func_bounds[name];
                Interval t_interval = target_bounds[i];
                t_interval.min = simplify(t_interval.min);
                t_interval.max = simplify(t_interval.max);
                Interval r_interval(simplify(mul_a * mulb_rvar.min),
                                    simplify(mul_a * mulb_rvar.extent - 1));

                if (can_prove(r_interval.min <= t_interval.min &&
                              r_interval.max >= t_interval.max)) {
                    // We've finally made sure that the expression has the form we want
                    // Now replace everything
                    // replace s * r.x + r.y with x
                    lhs[i] = func_to_update_args[i];
                    adjoint = substitute(lhs_arg,
                                         func_to_update_args[i],
                                         substitute_in_all_lets(adjoint));
                    // replace r.x with x / s
                    adjoint = substitute(mul_b, func_to_update_args[i] / mul_a, adjoint);
                    adjoint = simplify(adjoint);
                }
            }
        }
    }

    // We can only have one RDom for each update.
    // Therefore we have to merge RDoms on both lhs and rhs
    // To make use of better locality we preserve partial order
    map<string, ReductionVariableInfo> rvar_maps =
        gather_rvariables(adjoint);
    for (const auto &lhs_arg : lhs) {
        map<string, ReductionVariableInfo> maps =
            gather_rvariables(lhs_arg);
        rvar_maps.insert(maps.begin(), maps.end());
    }
    // Original set of reduction variables
    map<string, ReductionVariableInfo> org_rvar_maps =
        gather_rvariables(adjoint_before_canonicalize);
    for (const auto &lhs_arg : lhs_before_canonicalize) {
        map<string, ReductionVariableInfo> maps =
            gather_rvariables(lhs_arg);
        org_rvar_maps.insert(maps.begin(), maps.end());
    }
    // If the update is non-commutative or non-associative, we need to flip the
    // original set of reduction variable
    if (is_current_non_overwriting_scan) {
        // For each lhs
        for (auto &lhs_arg : lhs) {
            // For each original rvar
            for (const auto &it : org_rvar_maps) {
                RVar r(it.second.domain, it.second.index);
                Expr max = simplify(it.second.min + it.second.extent - 1);
                // Replace the reduction with the flipped version
                lhs_arg = substitute(it.first, max - r, lhs_arg);
            }
        }
        // For adjoint
        // For each original rvar
        for (const auto &it : org_rvar_maps) {
            RVar r(it.second.domain, it.second.index);
            Expr max = simplify(it.second.min + it.second.extent - 1);
            // Replace the reduction with the flipped version
            adjoint = substitute(it.first, max - r, adjoint);
        }
    }

    // Order: newly introduced rvar -> original rvar
    vector<ReductionVariableInfo> new_rvar_vec, old_rvar_vec;
    for (const auto &it : rvar_maps) {
        if (org_rvar_maps.find(it.first) == org_rvar_maps.end()) {
            new_rvar_vec.push_back(it.second);
        } else {
            old_rvar_vec.push_back(it.second);
        }
    }

    // Sort by index & domain
    auto cmp_rv = [](const ReductionVariableInfo &rv0,
                     const ReductionVariableInfo &rv1) {
        ReductionDomain::Compare cmp;
        if (cmp(rv0.domain, rv1.domain)) {
            return true;
        } else {
            return rv0.index < rv1.index;
        }
    };
    std::sort(new_rvar_vec.begin(), new_rvar_vec.end(), cmp_rv);
    std::sort(old_rvar_vec.begin(), old_rvar_vec.end(), cmp_rv);
    // Flatten to an array
    vector<string> var_names;
    Region merged_bounds;
    for (const auto &it : new_rvar_vec) {
        var_names.push_back(it.name);
        merged_bounds.emplace_back(it.min, it.extent);
    }
    for (const auto &it : old_rvar_vec) {
        var_names.push_back(it.name);
        merged_bounds.emplace_back(it.min, it.extent);
    }
    // Produce final merged RDom
    RDom merged_r;
    if (!merged_bounds.empty()) {
        merged_r = RDom(merged_bounds);
        // Transfer the predicate from old RDoms to merged RDom
        // Gather the set of RDoms
        set<ReductionDomain, ReductionDomain::Compare> rdoms;
        for (const auto &it : rvar_maps) {
            rdoms.insert(it.second.domain);
        }
        Expr rdom_predicate = Internal::UIntImm::make(UInt(1), 1);
        for (const auto &rdom : rdoms) {
            rdom_predicate = simplify(rdom_predicate && rdom.predicate());
        }
        // Reference to new RDom
        for (int rid = 0; rid < merged_r.dimensions(); rid++) {
            adjoint = substitute(var_names[rid], merged_r[rid], adjoint);
            for (auto &lhs_arg : lhs) {
                lhs_arg = substitute(var_names[rid], merged_r[rid], lhs_arg);
            }
            rdom_predicate = substitute(
                var_names[rid], merged_r[rid], rdom_predicate);
        }
        if (!is_const(rdom_predicate)) {
            for (int arg_id = 0; arg_id < (int)func_to_update_args.size(); arg_id++) {
                // Substitute new_args back to original variables
                rdom_predicate = substitute(new_args[arg_id].name(),
                                            func_to_update_args[arg_id], rdom_predicate);
            }
            merged_r.where(rdom_predicate);
        }
    }

    // Substitute new_args back to original variables
    for (int arg_id = 0; arg_id < (int)func_to_update_args.size(); arg_id++) {
        for (auto &lhs_arg : lhs) {
            lhs_arg = substitute(new_args[arg_id].name(),
                                 func_to_update_args[arg_id], lhs_arg);
        }
        adjoint = substitute_rdom_predicate(
            new_args[arg_id].name(), func_to_update_args[arg_id], adjoint);
    }

    // Simplify expressions
    adjoint = simplify(common_subexpression_elimination(adjoint));
    for (int i = 0; i < (int)lhs.size(); i++) {
        lhs[i] = simplify(common_subexpression_elimination(lhs[i]));
    }

    if (debug_flag) {
        debug(0) << "func_to_update.name():" << func_to_update.name() << "\n";
        debug(0) << "lhs after canonicalization:";
        for (const auto &arg : lhs) {
            debug(0) << " " << arg;
        }
        debug(0) << "\n";
        debug(0) << "adjoint after canonicalization:" << simplify(adjoint) << "\n";
    }

    // Finally we update the function definitions, possibly merge with previous updates
    auto can_merge = [&](Func &func_to_update,
                         const vector<Expr> &lhs) -> bool {
        if (func_to_update.num_update_definitions() == 0) {
            // If lhs are not pure variables we can't merge to pure definition
            for (int i = 0; i < (int)lhs.size(); i++) {
                if (!equal(lhs[i], func_to_update.args()[i])) {
                    return false;
                }
            }
            ReductionDomain rdom = extract_rdom(adjoint);
            // If there are rdoms in adjoint we can't merge
            return !rdom.defined();
        }
        int update_id = func_to_update.num_update_definitions() - 1;
        vector<Expr> prev_lhs =
            func_to_update.update_args(update_id);
        internal_assert(prev_lhs.size() == lhs.size());
        // If previous update has different left hand side, don't merge
        for (int i = 0; i < (int)prev_lhs.size(); i++) {
            if (!equal(lhs[i], prev_lhs[i])) {
                return false;
            }
        }
        // If previous update has a different set of reduction variables,
        // don't merge
        const vector<ReductionVariable> &rvars =
            func_to_update.update(update_id).get_schedule().rvars();
        if (!merged_r.defined()) {
            return rvars.empty();
        }
        if ((int)rvars.size() != merged_r.dimensions()) {
            return false;
        }

        for (int i = 0; i < (int)rvars.size(); i++) {
            if (!equal(rvars[i].min, merged_r[i].min())) {
                return false;
            }
            if (!equal(rvars[i].extent, merged_r[i].extent())) {
                return false;
            }
        }
        return true;
    };
    if (is_self_referencing_phase) {
        // If this is a self reference call, the relation is = instead of +=
        // For example, consider this:
        // f(x) = g(x)
        // f(k(r.x)) += h(r.x)
        // Multiple k(r.x) may correspond to the same index,
        // but they are overwritten in the reduction loop.
        // Therefore we should also overwrite their derivatives
        // by using = instead of +=
        if (!can_merge(func_to_update, lhs)) {
            if (func_to_update.values().size() == 1) {
                func_to_update(lhs) = adjoint;
            } else {
                func_to_update(lhs)[value_index] = adjoint;
            }
        } else {
            Definition &def = func_to_update.num_update_definitions() == 0 ? func_to_update.function().definition() : func_to_update.function().update(func_to_update.num_update_definitions() - 1);
            vector<Expr> &values = def.values();
            ReductionDomain rdom;
            for (const auto &val : values) {
                rdom = extract_rdom(val);
                if (rdom.defined()) {
                    break;
                }
            }
            if (rdom.defined()) {
                internal_assert(func_to_update.num_update_definitions() > 0);
                // Make sure we're using the same set of reduction variables
                for (int i = 0; i < merged_r.dimensions(); i++) {
                    adjoint = substitute(merged_r[i].name(), RVar(rdom, i), adjoint);
                }
            }

            if (values.size() == 1) {
                values[0] = adjoint;
            } else {
                values[value_index] = adjoint;
            }
        }
        return;
    }

    if (!can_merge(func_to_update, lhs)) {
        if (func_to_update.values().size() == 1) {
            func_to_update(lhs) += adjoint;
        } else {
            func_to_update(lhs)[value_index] += adjoint;
        }
    } else {
        Definition &def = func_to_update.num_update_definitions() == 0 ? func_to_update.function().definition() : func_to_update.function().update(func_to_update.num_update_definitions() - 1);
        vector<Expr> &values = def.values();
        ReductionDomain rdom;
        for (const auto &val : values) {
            rdom = extract_rdom(val);
            if (rdom.defined()) {
                break;
            }
        }
        if (rdom.defined()) {
            internal_assert(func_to_update.num_update_definitions() > 0);
            // Make sure we're using the same set of reduction variables
            for (int i = 0; i < merged_r.dimensions(); i++) {
                adjoint = substitute(merged_r[i].name(), RVar(rdom, i), adjoint);
            }
        }

        if (values.size() == 1) {
            values[0] = simplify(values[0] + adjoint);
        } else {
            const Add *add = values[value_index].as<Add>();
            if (add != nullptr &&
                add->b.as<Call>() != nullptr &&
                add->b.as<Call>()->is_intrinsic(Call::undef)) {
                // Sometimes the expression is an undef for the case of a tuple.
                // Make sure we don't include the undefs
                values[value_index] = simplify(add->a + adjoint);
            } else {
                values[value_index] =
                    simplify(values[value_index] + adjoint);
            }
        }
    }
}

}  // namespace
}  // namespace Internal

Func Derivative::operator()(const Func &func, int update_id) const {
    // Adjoints are keyed on the Func's name paired with the requested
    // update index.
    const auto found = adjoints.find(FuncKey{func.name(), update_id});
    if (found != adjoints.end()) {
        return found->second;
    }
    // No adjoint was recorded for this (name, update) pair: log at debug
    // level 1 and hand back a default-constructed Func.
    Internal::debug(1) << "Could not find Func " << func.name() << "\n";
    return Func();
}

Func Derivative::operator()(const Buffer<> &buffer) const {
    // Buffer adjoints are stored under update index -1.
    const FuncKey key{buffer.name(), -1};
    const auto entry = adjoints.find(key);
    const bool missing = (entry == adjoints.end());
    if (missing) {
        // Nothing recorded for this buffer: report at debug level 1 and
        // return a default-constructed Func.
        Internal::debug(1) << "Could not find Buffer " << buffer.name() << "\n";
        return Func();
    }
    return entry->second;
}

Func Derivative::operator()(const Param<> &param) const {
    // Param adjoints, like Buffer adjoints, are stored under update index -1.
    const auto &param_name = param.name();
    const auto match = adjoints.find(FuncKey{param_name, -1});
    if (match == adjoints.end()) {
        // Missing entry: log at debug level 1 and fall back to a
        // default-constructed Func.
        Internal::debug(1) << "Could not find Param " << param_name << "\n";
        return Func();
    }
    return match->second;
}

Derivative propagate_adjoints(const Func &output,
                              const Func &adjoint,
                              const Region &output_bounds) {
    // The seed adjoint, the output and the requested bounds must all agree
    // on dimensionality before any differentiation is attempted.
    user_assert(output.dimensions() == adjoint.dimensions())
        << "output dimensions and adjoint dimensions must match\n";
    user_assert((int)output_bounds.size() == adjoint.dimensions())
        << "output_bounds and adjoint dimensions must match\n";

    // Run the reverse-accumulation pass seeded with `adjoint` over
    // `output_bounds`, then package the per-(name, update) adjoint Funcs it
    // collected.
    Internal::ReverseAccumulationVisitor reverse_pass;
    reverse_pass.propagate_adjoints(output, adjoint, output_bounds);
    // get_adjoint_funcs() returns its map by value, so the argument is
    // already a temporary; wrapping it in std::move would add nothing.
    return Derivative{reverse_pass.get_adjoint_funcs()};
}

// Convenience overload: back-propagate a concrete adjoint buffer through
// `output`. The bounds handed to the general overload are the buffer's own
// domain, and the buffer is wrapped so that reads outside it yield 0.
Derivative propagate_adjoints(const Func &output,
                              const Buffer<float> &adjoint) {
    user_assert(output.dimensions() == adjoint.dimensions())
        << "output dimensions and adjoint dimensions must match\n";
    Region bounds;
    bounds.reserve(adjoint.dimensions());
    for (int dim = 0; dim < adjoint.dimensions(); dim++) {
        // A Region entry is a Range, i.e. {min, extent} -- the same
        // convention this file uses when building RDoms from a Region.
        // The second field must be the extent, not the max coordinate
        // (min + extent - 1); passing the max miscounts the domain
        // whenever min != 1.
        bounds.emplace_back(adjoint.min(dim), adjoint.extent(dim));
    }
    // constant_exterior makes out-of-domain reads return 0.f, so the adjoint
    // Func can be sampled anywhere bounds inference requests.
    Func adjoint_func = BoundaryConditions::constant_exterior(adjoint, 0.f);
    return propagate_adjoints(output, adjoint_func, bounds);
}

Derivative propagate_adjoints(const Func &output) {
    // Seed the reverse pass with an adjoint that is identically one over
    // output's pure arguments, typed like output's value.
    Func seed("adjoint");
    seed(output.args()) = Internal::make_one(output.value().type());
    // One {0, 0} Range per dimension of output (no entries at all when output
    // is zero-dimensional).
    // NOTE(review): Range is {min, extent}, so {0, 0} has zero extent; confirm
    // this is the intended bound for outputs with dimensions > 0.
    Region output_bounds(output.dimensions(), {0, 0});
    return propagate_adjoints(output, seed, output_bounds);
}

}  // namespace Halide
back to top