https://github.com/shader-slang/slang
Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-autodiff-rev.h
// slang-ir-autodiff-rev.h
#pragma once

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-compiler.h"

#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-transcriber-base.h"
#include "slang-ir-autodiff-propagate.h"
#include "slang-ir-autodiff-unzip.h"
#include "slang-ir-autodiff-transpose.h"

namespace Slang
{

struct IRReverseDerivativePassOptions
{
    // Nothing for now..
};

// The result of function parameter transposition.
// Contains necessary info for future processing in the backward differentation pass.
struct ParameterBlockTransposeInfo
{
    // Parameters that should be in the furture primal function.
    HashSet<IRInst*> primalFuncParams;

    // Parameters that should be in the furture propagate function.
    HashSet<IRInst*> propagateFuncParams;

    // The value with which a primal specific parameter should be replaced in propagate func.
    OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc;
    // The insts added that is specific for propagate functions and should be removed

    // from the future primal func.
    List<IRInst*> propagateFuncSpecificPrimalInsts;

    // Write backs to perform at the end of the back-prop function in order to return the
    // computed output derivatives for an inout parameter.
    OrderedDictionary<IRInst*, InstPair> outDiffWritebacks;

    // The dOut parameter representing the result derivative to propagate backwards through.
    IRInst* dOutParam;
};

struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
{
    FuncBodyTranscriptionTaskType diffTaskType;

    // Map that stores the upper gradient given an IRInst*
    Dictionary<IRInst*, List<IRInst*>> upperGradients;
    Dictionary<IRInst*, IRInst*> primalToDiffPair;
    Dictionary<IRInst*, IRInst*> orginalToTranscribed;

    // References to other passes that for reverse-mode transcription.
    DiffTransposePass* diffTransposePass;
    DiffPropagationPass* diffPropagationPass;
    DiffUnzipPass* diffUnzipPass;

    // Allocate space for the passes.
    DiffTransposePass               diffTransposePassStorage;
    DiffPropagationPass             diffPropagationPassStorage;
    DiffUnzipPass                   diffUnzipPassStorage;

    BackwardDiffTranscriberBase(
        FuncBodyTranscriptionTaskType taskType,
        AutoDiffSharedContext* shared,
        DiagnosticSink* inSink)
        : AutoDiffTranscriberBase(shared, inSink)
        , diffTaskType(taskType)
        , diffTransposePassStorage(shared)
        , diffPropagationPassStorage(shared)
        , diffUnzipPassStorage(shared)
        , diffTransposePass(&diffTransposePassStorage)
        , diffPropagationPass(&diffPropagationPassStorage)
        , diffUnzipPass(&diffUnzipPassStorage)
    {}

    // Returns "dp<var-name>" to use as a name hint for parameters.
    // If no primal name is available, returns a blank string.
    // 
    String makeDiffPairName(IRInst* origVar);
        
    IRFuncType* differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermediateType);

    IRType* transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType);
    IRType* transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType);

    // Puts parameters into their own block.
    void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func);

    // Transcribe a function definition.
    virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0;

    // Get transcribed function name from original name.
    virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) = 0;

    // Splits and transpose the parameter block.
    // After this operation, the parameter block will contain parameters for both the future
    // primal func and the future propagate func.
    // Additional info is returned in `ParameterBlockTransposeInfo` for future processing such
    // as inserting write-back logic or splitting them into different functions.
    ParameterBlockTransposeInfo splitAndTransposeParameterBlock(
        IRBuilder* builder,
        IRFunc* diffFunc,
        bool isResultDifferentiable);
    
    void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc);
    
    virtual InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) override;

    InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);

    SlangResult prepareFuncForBackwardDiff(IRFunc* func);

    IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc);

    void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc);

    InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);

    void addTranscribedFuncDecoration(IRBuilder& builder, IRFunc* origFunc, IRFunc* transcribedFunc);

    virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;

    virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;

    virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) = 0;
    virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) = 0;

    virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
    {
        return kIROp_BackwardDerivativeDecoration;
    }
};

struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase
{
    BackwardDiffPrimalTranscriber(
        AutoDiffSharedContext* shared,
        DiagnosticSink* inSink)
        : BackwardDiffTranscriberBase(
              FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSink)
    { }

    virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
    virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override;
    virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
    {
        if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>())
        {
            return backDecor->getBackwardDerivativePrimalFunc();
        }
        return nullptr;
    }
    virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
    {
        builder->addBackwardDerivativePrimalDecoration(inst, diffFunc);
    }
    virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
    {
        return kIROp_BackwardDerivativePrimalDecoration;
    }
    virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override
    {
        if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
        {
            StringBuilder sbuilder;
            sbuilder << "s_primal_ctx_" << nameHint->getName();
            return builder->getStringValue(sbuilder.getUnownedSlice());
        }
        else
        {
            return builder->getStringValue(String("s_primal_ctx_anonymous").getUnownedSlice());
        }
    }
};

struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
{
    BackwardDiffPropagateTranscriber(
        AutoDiffSharedContext* shared,
        DiagnosticSink* inSink)
        : BackwardDiffTranscriberBase(
              FuncBodyTranscriptionTaskType::BackwardPropagate,
              shared,
              inSink)
    { }
    void generateTrivialDiffFuncFromUserDefinedDerivative(
        IRBuilder* builder,
        IRFunc* primalFunc,
        IRFunc* diffPropFunc,
        IRUserDefinedBackwardDerivativeDecoration* udfDecor);

    virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
    virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override;
    virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
    {
        if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativePropagateDecoration>())
        {
            return backDecor->getBackwardDerivativePropagateFunc();
        }
        return nullptr;
    }
    virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
    {
        builder->addBackwardDerivativePropagateDecoration(inst, diffFunc);
    }
    virtual IROp getInterfaceRequirementDerivativeDecorationOp() override
    {
        return kIROp_BackwardDerivativePropagateDecoration;
    }
    virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override
    {
        if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
        {
            StringBuilder sbuilder;
            sbuilder << "s_bwd_prop_" << nameHint->getName();
            return builder->getStringValue(sbuilder.getUnownedSlice());
        }
        else
        {
            return builder->getStringValue(String("s_bwd_prop_anonymous").getUnownedSlice());
        }
    }
};

// A backward derivative function combines both primal + propagate functions and accepts no
// intermediate value input.
struct BackwardDiffTranscriber : BackwardDiffTranscriberBase
{
    BackwardDiffTranscriber(
        AutoDiffSharedContext* shared,
        DiagnosticSink* inSink)
        : BackwardDiffTranscriberBase(
              FuncBodyTranscriptionTaskType::Backward, shared, inSink)
    { }

    virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
    virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
    virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override
    {
        // Don't need to do anything here, the body is generated in transcribeFuncHeader.

        SLANG_UNUSED(builder);
        addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
        return InstPair(primalFunc, diffFunc);
    }
    virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
    {
        if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativeDecoration>())
        {
            return backDecor->getBackwardDerivativeFunc();
        }
        if (auto backDecor = originalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
        {
            return backDecor->getBackwardDerivativeFunc();
        }
        return nullptr;
    }
    virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
    {
        builder->addBackwardDerivativeDecoration(inst, diffFunc);
    }
    virtual IRStringLit* getTranscribedFuncName(IRBuilder* builder, IRGlobalValueWithCode* func) override
    {
        if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
        {
            StringBuilder sbuilder;
            sbuilder << "s_bwd_" << nameHint->getName();
            return builder->getStringValue(sbuilder.getUnownedSlice());
        }
        else
        {
            return builder->getStringValue(String("s_bwd_anonymous").getUnownedSlice());
        }
    }
};

}
back to top