https://github.com/shader-slang/slang
Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-autodiff.h
// slang-ir-autodiff-fwd.h
#pragma once

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-compiler.h"

#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"

namespace Slang
{
template<typename P, typename D>
struct DiffInstPair
{
    P primal;
    D differential;
    DiffInstPair() = default;
    DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
    {}
    HashCode getHashCode() const
    {
        Hasher hasher;
        hasher << primal << differential;
        return hasher.getResult();
    }
    bool operator ==(const DiffInstPair& other) const
    {
        return primal == other.primal && differential == other.differential;
    }
};

typedef DiffInstPair<IRInst*, IRInst*> InstPair;

enum class FuncBodyTranscriptionTaskType
{
    Forward, BackwardPrimal, BackwardPropagate, Backward
};

struct FuncBodyTranscriptionTask
{
    FuncBodyTranscriptionTaskType type;
    IRFunc* originalFunc;
    IRFunc* resultFunc;
};

struct AutoDiffTranscriberBase;

struct DiffTranscriberSet
{
    AutoDiffTranscriberBase* forwardTranscriber = nullptr;
    AutoDiffTranscriberBase* primalTranscriber = nullptr;
    AutoDiffTranscriberBase* propagateTranscriber = nullptr;
    AutoDiffTranscriberBase* backwardTranscriber = nullptr;
};

struct AutoDiffSharedContext
{
    TargetRequest* targetRequest = nullptr;

    IRModuleInst* moduleInst = nullptr;

    // A reference to the builtin IDifferentiable interface type.
    // We use this to look up all the other types (and type exprs)
    // that conform to a base type.
    // 
    IRInterfaceType* differentiableInterfaceType = nullptr;

    // The struct key for the 'Differential' associated type
    // defined inside IDifferential. We use this to lookup the differential
    // type in the conformance table associated with the concrete type.
    // 
    IRStructKey* differentialAssocTypeStructKey = nullptr;

    // The struct key for the witness that `Differential` associated type conforms to
    // `IDifferential`.
    IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;


    // The struct key for the 'zero()' associated type
    // defined inside IDifferential. We use this to lookup the 
    // implementation of zero() for a given type.
    // 
    IRStructKey* zeroMethodStructKey = nullptr;

    // The struct key for the 'add()' associated type
    // defined inside IDifferential. We use this to lookup the 
    // implementation of add() for a given type.
    // 
    IRStructKey* addMethodStructKey = nullptr;

    IRStructKey* mulMethodStructKey = nullptr;

    // Refernce to NullDifferential struct type. These are used
    // as sentinel values for uninitialized existential (interface-typed) 
    // differentials.
    //
    IRStructType* nullDifferentialStructType = nullptr;

    // Reference to the NullDifferential : IDifferentiable witness.
    //
    IRInst* nullDifferentialWitness = nullptr;


    // Modules that don't use differentiable types
    // won't have the IDifferentiable interface type available. 
    // Set to false to indicate that we are uninitialized.
    // 
    bool                                    isInterfaceAvailable = false;

    List<FuncBodyTranscriptionTask>         followUpFunctionsToTranscribe;

    DiffTranscriberSet transcriberSet;

    AutoDiffSharedContext(TargetRequest* target, IRModuleInst* inModuleInst);

private:

    IRInst* findDifferentiableInterface();

    IRStructType *findNullDifferentialStructType();

    IRInst *findNullDifferentialWitness();

    IRStructKey* findDifferentialTypeStructKey()
    {
        return getIDifferentiableStructKeyAtIndex(0);
    }

    IRStructKey* findDifferentialTypeWitnessStructKey()
    {
        return getIDifferentiableStructKeyAtIndex(1);
    }

    IRStructKey* findZeroMethodStructKey()
    {
        return getIDifferentiableStructKeyAtIndex(2);
    }

    IRStructKey* findAddMethodStructKey()
    {
        return getIDifferentiableStructKeyAtIndex(3);
    }

    IRStructKey* findMulMethodStructKey()
    {
        return getIDifferentiableStructKeyAtIndex(4);
    }

    IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
};

struct DifferentiableTypeConformanceContext
{
    AutoDiffSharedContext* sharedContext;

    IRGlobalValueWithCode* parentFunc = nullptr;
    OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary;

    IRFunc* existentialDAddFunc = nullptr;

    DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
        : sharedContext(shared)
    {
        // Populate dictionary with null differential type.
        if (sharedContext->nullDifferentialStructType)
            differentiableWitnessDictionary.add(
                sharedContext->nullDifferentialStructType,
                sharedContext->nullDifferentialWitness);
    }

    void setFunc(IRGlobalValueWithCode* func);

    void buildGlobalWitnessDictionary();

    // Lookup a witness table for the concreteType. One should exist if concreteType
    // inherits (successfully) from IDifferentiable.
    // 
    IRInst* lookUpConformanceForType(IRInst* type);

    IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);

    IRType* differentiateType(IRBuilder* builder, IRInst* primalType);

    IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);

    IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType);

    IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType);

    IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType);

    IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
    
    IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType);

    IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);

    IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);

    IRInst* getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);

    IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);

    IRInst* tryExtractConformanceFromInterfaceType(
        IRBuilder* builder,
        IRInterfaceType* interfaceType,
        IRWitnessTable* witnessTable);
    
    List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath(
        IRInterfaceType* idiffType,
        IRInterfaceType* type);

    // Lookup and return the 'Differential' type declared in the concrete type
    // in order to conform to the IDifferentiable interface.
    // Note that inside a generic block, this will be a witness table lookup instruction
    // that gets resolved during the specialization pass.
    // 
    IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
    {
        switch (origType->getOp())
        {
        case kIROp_InterfaceType:
        {
            if (isDifferentiableType(origType))
                return this->sharedContext->differentiableInterfaceType;
            else
                return nullptr;
        }
        case kIROp_ArrayType:
        {
            auto diffElementType = (IRType*)getDifferentialForType(
                builder, as<IRArrayType>(origType)->getElementType());
            if (!diffElementType)
                return nullptr;
            return builder->getArrayType(
                diffElementType,
                as<IRArrayType>(origType)->getElementCount());
        }
        case kIROp_DifferentialPairUserCodeType:
        {
            auto diffPairType = as<IRDifferentialPairTypeBase>(origType);
            auto diffType = getDiffTypeFromPairType(builder, diffPairType);
            auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType);
            return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness);
        }
        default:
            return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
        }
    }

    bool isDifferentiableType(IRType* origType)
    {
        for (; origType;)
        {
            switch (origType->getOp())
            {
            case kIROp_FloatType:
            case kIROp_HalfType:
            case kIROp_DoubleType:
            case kIROp_DifferentialPairType:
            case kIROp_DifferentialPairUserCodeType:
                return true;
            case kIROp_VectorType:
            case kIROp_ArrayType:
            case kIROp_PtrType:
            case kIROp_OutType:
            case kIROp_InOutType:
                origType = (IRType*)origType->getOperand(0);
                continue;
            default:
                return lookUpConformanceForType(origType) != nullptr;
            }
        }
        return false;
    }

    IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
    {
        auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
        return result;
    }

    IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
    {
        auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
        return result;
    }

    IRInst* emitNullDifferential(IRBuilder* builder)
    {
        return builder->emitCallInst(
            sharedContext->nullDifferentialStructType,
            getZeroMethodForType(builder, sharedContext->nullDifferentialStructType),
            List<IRInst*>());
    }

    IRFunc* getOrCreateExistentialDAddMethod();
    
};

struct DifferentialPairTypeBuilder
{
    DifferentialPairTypeBuilder() = default;

    DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {}

    IRStructField* findField(IRInst* type, IRStructKey* key);

    IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam);

    IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key);

    IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst);

    IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst);

    IRStructKey* _getOrCreateDiffStructKey();

    IRStructKey* _getOrCreatePrimalStructKey();

    IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType);

    IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);

    struct PairStructKey
    {
        IRInst* originalType;
        IRInst* diffType;
    };

    // Cache from `IRDifferentialPairType` to materialized struct type.
    Dictionary<IRInst*, IRInst*> pairTypeCache;

    IRStructKey* globalPrimalKey = nullptr;

    IRStructKey* globalDiffKey = nullptr;

    IRInst* genericDiffPairType = nullptr;

    List<IRInst*> generatedTypeList;

    AutoDiffSharedContext* sharedContext = nullptr;
};

void stripAutoDiffDecorations(IRModule* module);
void stripTempDecorations(IRInst* inst);

bool isNoDiffType(IRType* paramType);

IRInst* lookupForwardDerivativeReference(IRInst* primalFunction);

struct IRAutodiffPassOptions
{
    // Nothing for now...
};

bool processAutodiffCalls(
    TargetRequest* target,
    IRModule*                           module,
    DiagnosticSink*                     sink,
    IRAutodiffPassOptions const&   options = IRAutodiffPassOptions());

bool finalizeAutoDiffPass(TargetRequest* target, IRModule* module);

// Utility methods

void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst);

void stripDerivativeDecorations(IRInst* inst);

bool isBackwardDifferentiableFunc(IRInst* func);

bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst);

bool canTypeBeStored(IRInst* type);

inline bool isRelevantDifferentialPair(IRType* type)
{
    if (as<IRDifferentialPairType>(type))
    {
        return true;
    }
    else if (auto argPtrType = as<IRPtrTypeBase>(type))
    {
        if (as<IRDifferentialPairType>(argPtrType->getValueType()))
        {
            return true;
        }
    }
    return false;
}

IRBlock* getBlock(IRInst* inst);

IRInst* getInstInBlock(IRInst* inst);

UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg);

IRUse* findUniqueStoredVal(IRVar* var);
IRUse* findLatestUniqueWriteUse(IRVar* var);
IRUse* findEarliestUniqueWriteUse(IRVar* var);

bool isDerivativeContextVar(IRVar* var);

bool isDiffInst(IRInst* inst);

bool isDifferentialOrRecomputeBlock(IRBlock* block);

};
back to top