https://github.com/halide/Halide
Raw File
Tip revision: 52541176253e74467dabc42eeee63d9a62c199f6 authored by Steven Johnson on 20 February 2024, 17:13:06 UTC
HALIDE_VERSION_PATCH -> 1 (for 17.0.1) (#8113)
Tip revision: 5254117
EarlyFree.cpp
#include <map>
#include <utility>

#include "EarlyFree.h"
#include "ExprUsesVar.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "InjectHostDevBufferCopies.h"

namespace Halide {
namespace Internal {
namespace {

using std::string;

/** Visitor that finds the last "top-level" statement touching a given
 * allocation.
 *
 * A statement is top-level here if it is an element (first or rest) of a
 * Block that is not itself nested inside a For body, a Fork, an Acquire
 * body, or either branch of an IfThenElse. A use found inside one of those
 * constructs is attributed to the enclosing top-level statement instead. As
 * a result, a Free inserted after `last_use` never lands inside a loop body,
 * a fork, an acquire body, or one side of a conditional.
 *
 * Uses are: Loads, Stores and Calls whose name is `func`; Variables named
 * `func` or `func + ".buffer"`; and Atomic nodes whose mutex is `func`.
 *
 * After traversal, `last_use` holds the containing statement of the final
 * use seen in visitation order. It is undefined if there were no uses, or if
 * that final use had no enclosing Block element. */
class FindLastUse : public IRVisitor {
public:
    // Name of the allocation being tracked.
    string func;
    // Containing top-level statement of the most recently seen use of `func`.
    Stmt last_use;

    // `explicit` so a bare string cannot silently convert into a visitor.
    explicit FindLastUse(string s)
        : func(std::move(s)) {
    }

private:
    // True while visiting a region where a Free must not be injected
    // (loop/fork/acquire bodies, if-then-else branches).
    bool in_loop = false;
    // The top-level Block element currently being visited, if any.
    Stmt containing_stmt;

    using IRVisitor::visit;

    void visit(const For *loop) override {
        // Bounds are evaluated once, outside the loop body.
        loop->min.accept(this);
        loop->extent.accept(this);
        ScopedValue<bool> old_in_loop(in_loop, true);
        loop->body.accept(this);
    }

    void visit(const Fork *fork) override {
        // Both sides of a fork are treated like a loop body.
        ScopedValue<bool> old_in_loop(in_loop, true);
        fork->first.accept(this);
        fork->rest.accept(this);
    }

    void visit(const Acquire *acq) override {
        acq->semaphore.accept(this);
        acq->count.accept(this);
        // The acquire body is treated like a loop body.
        ScopedValue<bool> old_in_loop(in_loop, true);
        acq->body.accept(this);
    }

    void visit(const Load *load) override {
        if (func == load->name) {
            last_use = containing_stmt;
        }
        IRVisitor::visit(load);
    }

    void visit(const Call *call) override {
        if (call->name == func) {
            last_use = containing_stmt;
        }
        IRVisitor::visit(call);
    }

    void visit(const Store *store) override {
        if (func == store->name) {
            last_use = containing_stmt;
        }
        IRVisitor::visit(store);
    }

    void visit(const Variable *var) override {
        if (var->name == func || var->name == func + ".buffer") {
            // Don't free the allocation while a buffer that may refer
            // to it is still in use.
            last_use = containing_stmt;
        }
    }

    void visit(const IfThenElse *op) override {
        // It's a bad idea to inject it in either side of an
        // ifthenelse, so we treat this as being in a loop.
        op->condition.accept(this);
        ScopedValue<bool> old_in_loop(in_loop, true);
        op->then_case.accept(this);
        if (op->else_case.defined()) {
            op->else_case.accept(this);
        }
    }

    void visit(const Block *block) override {
        if (in_loop) {
            // Inside a loop-like region: keep attributing uses to the
            // enclosing top-level statement.
            IRVisitor::visit(block);
        } else {
            // Each element of this block becomes the containing statement
            // while it is visited. ScopedValue restores the outer value when
            // this block is done, so an enclosing Block's bookkeeping is kept.
            ScopedValue<Stmt> old_containing_stmt(containing_stmt, block->first);
            block->first.accept(this);
            if (block->rest.defined()) {
                containing_stmt = block->rest;
                block->rest.accept(this);
            }
        }
    }

    void visit(const Atomic *op) override {
        // An atomic section guarded by a mutex named `func` counts as a use.
        if (op->mutex_name == func) {
            last_use = containing_stmt;
        }
        IRVisitor::visit(op);
    }
};

/** Mutator that inserts `Free(func)` directly after the statement that is
 * the same node (by `same_as`) as `last_use`.
 *
 * Matches are only checked on the `first`/`rest` children of Blocks, and at
 * most one Free is inserted; once it is placed, the rest of the traversal
 * leaves statements untouched. */
class InjectMarker : public IRMutator {
public:
    // Name of the allocation to free.
    string func;
    // Statement after which the Free should be placed.
    Stmt last_use;

private:
    // Set once the Free has been inserted.
    bool free_inserted = false;

    using IRMutator::visit;

    // Returns `stmt` followed by a Free if it is the target. Otherwise
    // recurses into it, unless the Free has already been placed.
    Stmt try_inject(Stmt stmt) {
        if (free_inserted) {
            return stmt;
        }
        if (!stmt.same_as(last_use)) {
            return mutate(stmt);
        }
        free_inserted = true;
        return Block::make(stmt, Free::make(func));
    }

    Stmt visit(const Block *block) override {
        // The rest is searched before the first. Presumably this makes the
        // occurrence latest in program order win if the same node appears
        // more than once.
        Stmt rest = try_inject(block->rest);
        Stmt first = try_inject(block->first);

        if (!first.same_as(block->first) || !rest.same_as(block->rest)) {
            return Block::make(first, rest);
        }
        return block;
    }
};

/** Mutator that adds a Free for every Allocate.
 *
 * Nested allocations are handled first. For each allocation, the Free is
 * placed right after the last top-level statement that uses it, as found by
 * FindLastUse. If no such statement exists, the Free is appended to the end
 * of the allocation's body. */
class InjectEarlyFrees : public IRMutator {
    using IRMutator::visit;

    Stmt visit(const Allocate *op) override {
        // Recurse first so inner allocations already carry their Frees.
        Stmt result = IRMutator::visit(op);
        const Allocate *alloc = result.as<Allocate>();
        internal_assert(alloc);

        FindLastUse finder(alloc->name);
        result.accept(&finder);

        if (!finder.last_use.defined()) {
            // No safe early point was found: free at the end of the scope.
            Stmt body_then_free = Block::make(alloc->body, Free::make(alloc->name));
            return Allocate::make(alloc->name, alloc->type, alloc->memory_type,
                                  alloc->extents, alloc->condition,
                                  body_then_free,
                                  alloc->new_expr, alloc->free_function, alloc->padding);
        }

        InjectMarker injector;
        injector.func = alloc->name;
        injector.last_use = finder.last_use;
        return injector.mutate(result);
    }
};

}  // namespace

/** Returns a copy of `s` in which every Allocate is followed by a Free,
 * placed as early as it can safely go. */
Stmt inject_early_frees(const Stmt &s) {
    return InjectEarlyFrees().mutate(s);
}

}  // namespace Internal
}  // namespace Halide
back to top