https://github.com/halide/Halide
Revision d76970aa081df7d30b43a22295b02be759aae93c authored by Steven Johnson on 09 February 2021, 22:32:19 UTC, committed by Steven Johnson on 09 February 2021, 22:32:19 UTC
1 parent fe0888b
Raw File
Tip revision: d76970aa081df7d30b43a22295b02be759aae93c authored by Steven Johnson on 09 February 2021, 22:32:19 UTC
Fix apps/HelloPyTorch
Tip revision: d76970a
CodeGen_PyTorch.h
#ifndef HALIDE_CODEGEN_PYTORCH_H
#define HALIDE_CODEGEN_PYTORCH_H

/** \file
 *
 * Defines an IRPrinter that emits C++ code that:
 * 1. wraps PyTorch's C++ tensors into Halide buffers,
 * 2. calls the corresponding Halide operator,
 * 3. maps the output buffer back to a PyTorch tensor.
 *
 * The generated code checks for runtime errors and raises a PyTorch exception
 * accordingly. When applicable, it also makes sure the GPU device and stream
 * are consistent with those of the PyTorch input.
 */

#include "IRPrinter.h"

namespace Halide {

class Module;

namespace Internal {

struct LoweredFunc;

/** This class emits C++ code to wrap a Halide pipeline so that it can
 * be used as a C++ extension operator in PyTorch.
 *
 * It derives from IRPrinter and is constructed around an output stream.
 */
class CodeGen_PyTorch : public IRPrinter {
public:
    /** Construct a code generator around the output stream \p dest.
     *
     * The constructor is explicit so that a std::ostream is never
     * implicitly converted into a code generator.
     *
     * NOTE(review): presumably the emitted wrapper source is written to
     * \p dest and the stream must outlive this object -- confirm against
     * the constructor definition and IRPrinter. */
    explicit CodeGen_PyTorch(std::ostream &dest);
    ~CodeGen_PyTorch() override = default;

    /** Emit the PyTorch C++ wrapper for the Halide pipeline. */
    void compile(const Module &module);

    /** NOTE(review): presumably an internal self-test entry point for this
     * code generator; its behavior is defined out of line -- confirm. */
    static void test();

private:
    /** Overload of compile() that handles a single lowered function.
     *
     * NOTE(review): \p is_cuda presumably selects emission of the
     * CUDA-specific wrapper for \p func -- confirm against the
     * definition. */
    void compile(const LoweredFunc &func, bool is_cuda);
};

}  // namespace Internal
}  // namespace Halide

#endif  // HALIDE_CODEGEN_PYTORCH_H
back to top