https://github.com/shader-slang/slang
Raw File
Tip revision: 768e62f6c7541439e2edc18dad5fb3846d2e05f9 authored by Yong He on 10 October 2022, 22:59:45 UTC
Support multi-level break + single-return conversion + general inline. (#2436)
Tip revision: 768e62f
slang-ir-diff-jvp.cpp
// slang-ir-diff-jvp.cpp
#include "slang-ir-diff-jvp.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"

namespace Slang
{

template<typename P, typename D>
struct Pair
{
    P primal;
    D differential;

    Pair(P primal, D differential) : primal(primal), differential(differential)
    {}
};

typedef Pair<IRInst*, IRInst*> InstPair;

struct DifferentiableTypeConformanceContext
{
    Dictionary<IRInst*, IRInst*>    witnessTableMap;

    IRInst*                                 inst = nullptr;

    // A reference to the builtin IDifferentiable interface type.
    // We use this to look up all the other types (and type exprs)
    // that conform to a base type.
    // 
    IRInterfaceType*                        differentiableInterfaceType = nullptr;

    // The struct key for the 'Differential' associated type
    // defined inside IDifferential. We use this to lookup the differential
    // type in the conformance table associated with the concrete type.
    // 
    IRStructKey*                            differentialAssocTypeStructKey = nullptr;
    
    // Modules that don't use differentiable types
    // won't have the IDifferentiable interface type available. 
    // Set to false to indicate that we are uninitialized.
    // 
    bool                                    isInterfaceAvailable = false;

    // For handling generic blocks, we use a parent pointer to allow
    // looking up types in all relevant scopes.
    DifferentiableTypeConformanceContext*   parent = nullptr;

    DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst)
    {
        if (parent)
        {
            differentiableInterfaceType = parent->differentiableInterfaceType;
            differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey;
            isInterfaceAvailable = parent->isInterfaceAvailable;
        }
        else
        {
            differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
            if (differentiableInterfaceType)
            {
                differentialAssocTypeStructKey = findDifferentialTypeStructKey();

                if (differentialAssocTypeStructKey)
                    isInterfaceAvailable = true;
            }
        }

        if (isInterfaceAvailable)
        {
            // Load all witness tables corresponding to the IDifferentiable interface.
            loadWitnessTablesForInterface(differentiableInterfaceType);
        }
    }

    DifferentiableTypeConformanceContext(IRInst* inst) :
        DifferentiableTypeConformanceContext(nullptr, inst)
    {}

    // Lookup a witness table for the concreteType. One should exist if concreteType
    // inherits (successfully) from IDifferentiable.
    // 
    IRInst* lookUpConformanceForType(IRInst* type)
    {
        SLANG_ASSERT(isInterfaceAvailable);

        if (witnessTableMap.ContainsKey(type))
            return witnessTableMap[type];
        else if (parent)
            return parent->lookUpConformanceForType(type);
        else
            return nullptr;
    }
    
    // Lookup and return the 'Differential' type declared in the concrete type
    // in order to conform to the IDifferentiable interface.
    // Note that inside a generic block, this will be a witness table lookup instruction
    // that gets resolved during the specialization pass.
    // 
    IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
    {
        SLANG_ASSERT(isInterfaceAvailable);

        if (auto conformance = lookUpConformanceForType(origType))
        {
            if (auto witnessTable = as<IRWitnessTable>(conformance))
            {
                for (auto entry : witnessTable->getEntries())
                {
                    if (entry->getRequirementKey() == differentialAssocTypeStructKey)
                        return as<IRType>(entry->getSatisfyingVal());
                }
            }
            else if (auto witnessTableParam = as<IRParam>(conformance))
            {
                return builder->emitLookupInterfaceMethodInst(
                    builder->getTypeKind(),
                    witnessTableParam,
                    differentialAssocTypeStructKey);
            }
        }

        return nullptr;
    }

    private:

    IRInst* findDifferentiableInterface()
    {
        if (auto module = as<IRModuleInst>(inst))
        {
            for (auto globalInst : module->getGlobalInsts())
            {
                // TODO: This seems like a particularly dangerous way to look for an interface.
                // See if we can lower IDifferentiable to a separate IR inst.
                //
                if (globalInst->getOp() == kIROp_InterfaceType && 
                    as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
                {
                    return globalInst;
                }
            }
        }
        return nullptr;
    }

    IRStructKey* findDifferentialTypeStructKey()
    {
        if (as<IRModuleInst>(inst) && differentiableInterfaceType)
        {
            // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type.
            SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1);
            if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0)))
                return as<IRStructKey>(entry->getRequirementKey());
            else
            {
                SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
            }
        }

        return nullptr;
    }

    void loadWitnessTablesForInterface(IRInst* interfaceType)
    {
        
        if (auto module = as<IRModuleInst>(inst))
        {
            for (auto globalInst : module->getGlobalInsts())
            {
                if (globalInst->getOp() == kIROp_WitnessTable &&
                    cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() ==
                        interfaceType)
                {
                    // TODO: Can we have multiple conformances for the same pair of types?
                    // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can
                    // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)'
                    witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst);
                }
            }
        }
        else if (auto generic = as<IRGeneric>(inst))
        {
            List<IRParam*> typeParams;

            auto genericParam = generic->getFirstParam();
            while (genericParam)
            {
                if (as<IRTypeType>(genericParam->getDataType()))
                {
                    typeParams.add(genericParam);
                }
                else
                    break;
                
                genericParam = genericParam->getNextParam();
            }
            
            UCount tableIndex = 0;
            while (genericParam)
            {
                SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType()));
                if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType()))
                {
                    if (witnessTableType->getConformanceType() == differentiableInterfaceType)
                        witnessTableMap.Add(typeParams[tableIndex], genericParam);
                }
                else
                    break;

                tableIndex += 1;
                genericParam = genericParam->getNextParam();
            }
            
        }

    }

};

struct DifferentialPairTypeBuilder
{
    
    DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) :
        diffConformanceContext(diffConformanceContext)
    {}

    IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
    {
        if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
        {
            auto primalField = as<IRStructField>(basePairStructType->getFirstChild());
            SLANG_ASSERT(primalField);

            return as<IRFieldExtract>(builder->emitFieldExtract(
                    primalField->getFieldType(),
                    baseInst,
                    primalField->getKey()
                ));
        }
        else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
        {
            if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
            {
                auto primalField = as<IRStructField>(pairStructType->getFirstChild());
                SLANG_ASSERT(primalField);
                
                return as<IRFieldAddress>(builder->emitFieldAddress(
                        builder->getPtrType(primalField->getFieldType()),
                        baseInst,
                        primalField->getKey()
                    ));
            }
        }
        else
        {
            SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
        }
        return nullptr;
    }

    IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
    {
        if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
        {
            auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst());
            SLANG_ASSERT(diffField);

            return as<IRFieldExtract>(builder->emitFieldExtract(
                    diffField->getFieldType(),
                    baseInst,
                    diffField->getKey()
                ));
        }
        else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
        {
            if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
            {
                auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst());
                SLANG_ASSERT(diffField);
                
                return as<IRFieldAddress>(builder->emitFieldAddress(
                        builder->getPtrType(diffField->getFieldType()),
                        baseInst,
                        diffField->getKey()
                    ));
            }
        }
        else
        {
            SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
        }
        return nullptr;
    }
    
    IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
    {
        if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
        {
            auto diffPairType = builder->createStructType();

            // Create a keys for the primal and differential fields.
            IRStructKey* origKey = builder->createStructKey();
            builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal"));
            builder->createStructField(diffPairType, origKey, origBaseType);

            IRStructKey* diffKey = builder->createStructKey();
            builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
            builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType));

            return diffPairType;
        }
        return nullptr;
    }

    IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
    {
        if (pairTypeCache.ContainsKey(origBaseType))
            return pairTypeCache[origBaseType];

        auto pairType = _createDiffPairType(builder, origBaseType);
        pairTypeCache.Add(origBaseType, pairType);

        return pairType;
    }

    Dictionary<IRType*, IRStructType*> pairTypeCache;

    DifferentiableTypeConformanceContext* diffConformanceContext;

};

struct JVPTranscriber
{

    // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
    // their differential values.
    Dictionary<IRInst*, IRInst*>            instMapD;

    // Cloning environment to hold mapping from old to new copies for the primal
    // instructions.
    IRCloneEnv                              cloneEnv;

    // Diagnostic sink for error messages.
    DiagnosticSink*                         sink;

    // Type conformance information.
    DifferentiableTypeConformanceContext*   diffConformanceContext;

    // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
    DifferentialPairTypeBuilder*            pairBuilder;

    DiagnosticSink* getSink()
    {
        SLANG_ASSERT(sink);
        return sink;
    }

    void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
    {
        instMapD.Add(origInst, diffInst);
    }

    void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
    {
        if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
        {
            getSink()->diagnose(origInst->sourceLoc,
                Diagnostics::internalCompilerError,
                "inconsistent primal instruction for original");
        }
        else
        {
            cloneEnv.mapOldValToNew[origInst] = primalInst;
        }
    }

    IRInst* lookupDiffInst(IRInst* origInst)
    {
        return instMapD[origInst];
    }

    IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
    {
        return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
    }

    bool hasDifferentialInst(IRInst* origInst)
    {
        return instMapD.ContainsKey(origInst);
    }

    IRInst* lookupPrimalInst(IRInst* origInst)
    {
        return cloneEnv.mapOldValToNew[origInst];
    }

    IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
    {
        return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
    }

    bool hasPrimalInst(IRInst* origInst)
    {
        return cloneEnv.mapOldValToNew.ContainsKey(origInst);
    }
    
    IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
    {
        if (!hasDifferentialInst(origInst))
        {
            transcribe(builder, origInst);
            SLANG_ASSERT(hasDifferentialInst(origInst));
        }

        return lookupDiffInst(origInst);
    }

    IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
    {
        if (!hasPrimalInst(origInst))
        {
            transcribe(builder, origInst);
            SLANG_ASSERT(hasPrimalInst(origInst));
        }

        return lookupPrimalInst(origInst);
    }

    IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
    {
        List<IRType*> newParameterTypes;
        IRType* diffReturnType;

        for (UIndex i = 0; i < funcType->getParamCount(); i++)
        {
            auto origType = funcType->getParamType(i);
            if (auto diffPairType = tryGetDiffPairType(builder, origType))
                newParameterTypes.add(diffPairType);
            else
                newParameterTypes.add(origType);
        }

        // Transcribe return type to a pair.
        // This will be void if the primal return type is non-differentiable.
        //
        if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType()))
            diffReturnType = returnPairType;
        else
            diffReturnType = builder->getVoidType();

        return builder->getFuncType(newParameterTypes, diffReturnType);
    }

    IRType* differentiateType(IRBuilder* builder, IRType* origType)
    {
        switch (origType->getOp())
        {
            case kIROp_HalfType:
            case kIROp_FloatType:
            case kIROp_DoubleType:
            case kIROp_VectorType:
                return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType));
            case kIROp_OutType:
                return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType()));
            case kIROp_InOutType:
                return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType()));
            default:
                return nullptr;
        }
    }
    
    IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType)
    {
        // If this is a PtrType (out, inout, etc..), then create diff pair from
        // value type and re-apply the appropropriate PtrType wrapper.
        // 
        if (auto origPtrType = as<IRPtrTypeBase>(origType))
        {   
            if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
                return builder->getPtrType(origType->getOp(), diffPairValueType);
            else 
                return nullptr;
        }

        return pairBuilder->getOrCreateDiffPairType(builder, origType);
    }

    InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
    {
        if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType()))
        {
            IRParam* diffPairParam = builder->emitParam(diffPairType);

            auto diffPairVarName = makeDiffPairName(origParam);
            if (diffPairVarName.getLength() > 0)
                builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());

            SLANG_ASSERT(diffPairParam);

            return InstPair(
                pairBuilder->emitPrimalFieldAccess(builder, diffPairParam),
                pairBuilder->emitDiffFieldAccess(builder, diffPairParam));
        }
        
        return InstPair(
            cloneInst(&cloneEnv, builder, origParam),
            nullptr);
    }

    // Returns "d<var-name>" to use as a name hint for variables and parameters.
    // If no primal name is available, returns a blank string.
    // 
    String getJVPVarName(IRInst* origVar)
    {
        if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
        {
            return ("d" + String(namehintDecoration->getName()));
        }

        return String("");
    }

    // Returns "dp<var-name>" to use as a name hint for parameters.
    // If no primal name is available, returns a blank string.
    // 
    String makeDiffPairName(IRInst* origVar)
    {
        if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
        {
            return ("dp" + String(namehintDecoration->getName()));
        }

        return String("");
    }

    InstPair transcribeVar(IRBuilder* builder, IRVar* origVar)
    {
        if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
        {
            IRVar* diffVar = builder->emitVar(diffType);
            SLANG_ASSERT(diffVar);

            auto diffNameHint = getJVPVarName(origVar);
            if (diffNameHint.getLength() > 0)
                builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());

            return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar);
        }

        return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
    }

    InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
    {
        SLANG_ASSERT(origArith->getOperandCount() == 2);

        IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith);
        
        auto origLeft = origArith->getOperand(0);
        auto origRight = origArith->getOperand(1);

        auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
        auto primalRight = findOrTranscribePrimalInst(builder, origRight);
    
        auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
        auto diffRight = findOrTranscribeDiffInst(builder, origRight);

        auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0);
        auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0);

        if (diffLeft || diffRight)
        {
            diffLeft = diffLeft ? diffLeft : leftZero;
            diffRight = diffRight ? diffRight : rightZero;

            auto resultType = origArith->getDataType();
            switch(origArith->getOp())
            {
            case kIROp_Add:
                return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight));
            case kIROp_Mul:
                return InstPair(primalArith, builder->emitAdd(resultType,
                    builder->emitMul(resultType, diffLeft, primalRight),
                    builder->emitMul(resultType, primalLeft, diffRight)));
            case kIROp_Sub:
                return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight));
            case kIROp_Div:
                return InstPair(primalArith, builder->emitDiv(resultType, 
                    builder->emitSub(
                        resultType,
                        builder->emitMul(resultType, diffLeft, primalRight),
                        builder->emitMul(resultType, primalLeft, diffRight)),
                    builder->emitMul(
                        primalRight->getDataType(), primalRight, primalRight
                    )));
            default:
                getSink()->diagnose(origArith->sourceLoc,
                    Diagnostics::unimplemented,
                    "this arithmetic instruction cannot be differentiated");
            }
        }

        return InstPair(primalArith, nullptr);
    }

    InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
    {
        auto origPtr = origLoad->getPtr();
        
        auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);

        if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
        {
            IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
            SLANG_ASSERT(diffLoad);

            return InstPair(primalLoad, diffLoad);
        }
        return InstPair(primalLoad, nullptr);
    }

    InstPair transcribeStore(IRBuilder* builder, IRStore* origStore)
    {
        IRInst* origStoreLocation = origStore->getPtr();
        IRInst* origStoreVal = origStore->getVal();
        
        auto primalStore = cloneInst(&cloneEnv, builder, origStore);

        auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
        auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);

        // If the stored value has a differential version, 
        // emit a store instruction for the differential parameter.
        // Otherwise, emit nothing since there's nothing to load.
        // 
        if (diffStoreLocation && diffStoreVal)
        {
            IRStore* diffStore = as<IRStore>(
                    builder->emitStore(diffStoreLocation, diffStoreVal));
            SLANG_ASSERT(diffStore);
            
            return InstPair(primalStore, diffStore);
        }

        return InstPair(primalStore, nullptr);
    }

    InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
    {
        IRInst* origReturnVal = origReturn->getVal();
        
        if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType()))
        {   
            IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
        
            IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
            if(!diffReturnVal)
                diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType());

            auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
            IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
            return InstPair(pairReturn, pairReturn);
        }
        else
        {
            // If the differential return value is not available, emit a 
            // void return.
            IRInst* voidReturn = builder->emitReturn();
            return InstPair(voidReturn, voidReturn);
        }
    }

    // Since int/float literals are sometimes nested inside an IRConstructor
    // instruction, we check to make sure that the nested instr is a constant
    // and then return nullptr. Literals do not need to be differentiated.
    //
    InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
    {   
        IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);

        if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1)
            return InstPair(primalConstruct, nullptr);
        else
            getSink()->diagnose(origConstruct->sourceLoc,
                    Diagnostics::unimplemented,
                    "this construct instruction cannot be differentiated");

        return InstPair(primalConstruct, nullptr);
    }

    // Differentiating a call instruction here is primarily about generating
    // an appropriate call list based on whichever parameters have differentials 
    // in the current transcription context.
    // 
    InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
    {   
        if (auto origCallee = as<IRFunc>(origCall->getCallee()))
        {
            
            // Build the differential callee
            IRInst* diffCall = builder->emitJVPDifferentiateInst(
                differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())),
                origCallee);
            
            List<IRInst*> args;
            // Go over the parameter list and create pairs for each input (if required)
            for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
            {
                auto origArg = origCall->getArg(ii);
                auto primalArg = findOrTranscribePrimalInst(builder, origArg);
                SLANG_ASSERT(primalArg);

                auto origType = origArg->getDataType();
                if (auto pairType = tryGetDiffPairType(builder, origType))
                {
                    
                    auto diffArg = findOrTranscribeDiffInst(builder, origArg);

                    // TODO(sai): This part is flawed. Replace with a call to the 
                    // 'zero()' interface method.
                    if (!diffArg)
                        diffArg = getZeroOfType(builder, origType);
                    
                    auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);

                    args.add(diffPair);
                }
                else
                {
                    // Add original/primal argument.
                    args.add(primalArg);
                }
            }
            
            auto callInst = builder->emitCallInst(
                tryGetDiffPairType(builder, origCall->getFullType()),
                diffCall,
                args);
            
            return InstPair(
                pairBuilder->emitPrimalFieldAccess(builder, callInst),
                pairBuilder->emitDiffFieldAccess(builder, callInst));
        }
        else
        {
            // Note that this can only happen if the callee is a result
            // of a higher-order operation. For now, we assume that we cannot
            // differentiate such calls safely.
            // TODO(sai): Should probably get checked in the front-end.
            //
            getSink()->diagnose(origCall->sourceLoc,
                Diagnostics::internalCompilerError,
                "attempting to differentiate unresolved callee");
        }

        return InstPair(nullptr, nullptr);
    }

    InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
    {
        IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);

        if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
        {
            List<IRInst*> swizzleIndices;
            for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
                swizzleIndices.add(origSwizzle->getElementIndex(ii));
            
            return InstPair(
                primalSwizzle,
                builder->emitSwizzle(
                    differentiateType(builder, origSwizzle->getDataType()),
                    diffBase,
                    origSwizzle->getElementCount(),
                    swizzleIndices.getBuffer()));
        }

        return InstPair(primalSwizzle, nullptr);
    }

    InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
    {
        IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);

        UCount operandCount = origInst->getOperandCount();

        List<IRInst*> diffOperands;
        for (UIndex ii = 0; ii < operandCount; ii++)
        {
            // If the operand has a differential version, replace the original with the 
            // differential.
            // Otherwise, abandon the differentiation attempt and assume that origInst 
            // cannot (or does not need to) be differentiated.
            // 
            if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr))
                diffOperands.add(diffInst);
            else
                return InstPair(primalInst, nullptr);
        }
        
        return InstPair(
            primalInst, 
            builder->emitIntrinsicInst(
                differentiateType(builder, origInst->getDataType()),
                origInst->getOp(),
                operandCount,
                diffOperands.getBuffer()));
    }

    InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
    {
        switch(origInst->getOp())
        {
            case kIROp_unconditionalBranch:
                auto origBranch = as<IRUnconditionalBranch>(origInst);

                // Branches with extra operands not handled currently.
                if (origBranch->getOperandCount() > 1)
                    break;

                IRInst* diffBranch = nullptr;

                if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr))
                    diffBranch = builder->emitBranch(as<IRBlock>(diffBlock));

                // For now, every block in the original fn must have a corresponding
                // block to compute both primals and derivatives.
                SLANG_ASSERT(diffBranch);

                return InstPair(diffBranch, diffBranch);
        }

        getSink()->diagnose(
            origInst->sourceLoc,
            Diagnostics::unimplemented,
            "attempting to differentiate unhandled control flow");

        return InstPair(nullptr, nullptr);
    }

    
    InstPair transcribeConst(IRBuilder*, IRInst* origInst)
    {
        switch(origInst->getOp())
        {
            case kIROp_FloatLit:
                return InstPair(origInst, nullptr);
        }

        getSink()->diagnose(
            origInst->sourceLoc,
            Diagnostics::unimplemented,
            "attempting to differentiate unhandled const type");

        return InstPair(nullptr, nullptr);
    }

    // In differential computation, the 'default' differential value is always zero.
    // This is a consequence of differential computing being inherently linear. As a 
    // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
    // 
    IRInst* getZeroOfType(IRBuilder* builder, IRType* type)
    {
        switch (type->getOp())
        {
            case kIROp_FloatType:
            case kIROp_HalfType:
            case kIROp_DoubleType:
                return builder->getFloatValue(type, 0.0);
            case kIROp_IntType:
                return builder->getIntValue(type, 0);
            case kIROp_VectorType:
            {
                IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())};
                return builder->emitIntrinsicInst(
                    type,
                    kIROp_constructVectorFromScalar,
                    1,
                    args);
            }
            default:
                getSink()->diagnose(type->sourceLoc,
                    Diagnostics::internalCompilerError,
                    "could not generate zero value for given type");
                return nullptr;       
        }
    }

    IRInst* transcribe(IRBuilder* builder, IRInst* origInst)
    {
        InstPair pair = transcribeInst(builder, origInst);

        if (auto primalInst = pair.primal)
        {
            mapPrimalInst(origInst, pair.primal);

            mapDifferentialInst(origInst, pair.differential);
            return pair.differential;
        }

        getSink()->diagnose(origInst->sourceLoc,
                    Diagnostics::internalCompilerError,
                    "failed to transcibe instruction");
        return nullptr;
    }

    InstPair transcribeInst(IRBuilder* builder, IRInst* origInst)
    {
        // Handle common operations
        switch (origInst->getOp())
        {
        case kIROp_Param:
            return transcribeParam(builder, as<IRParam>(origInst));

        case kIROp_Var:
            return transcribeVar(builder, as<IRVar>(origInst));

        case kIROp_Load:
            return transcribeLoad(builder, as<IRLoad>(origInst));

        case kIROp_Store:
            return transcribeStore(builder, as<IRStore>(origInst));

        case kIROp_Return:
            return transcribeReturn(builder, as<IRReturn>(origInst));

        case kIROp_Add:
        case kIROp_Mul:
        case kIROp_Sub:
        case kIROp_Div:
            return transcribeBinaryArith(builder, origInst);

        case kIROp_Construct:
            return transcribeConstruct(builder, origInst);
        
        case kIROp_Call:
            return transcribeCall(builder, as<IRCall>(origInst));
        
        case kIROp_swizzle:
            return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
        
        case kIROp_constructVectorFromScalar:
            return transcribeByPassthrough(builder, origInst);

        case kIROp_unconditionalBranch:
        case kIROp_conditionalBranch:
            return transcribeControlFlow(builder, origInst);

        case kIROp_FloatLit:
            return transcribeConst(builder, origInst);

        }
    
        // If none of the cases have been hit, check if the instruction is a
        // type.
        // For now we don't have logic to differentiate types that appear in blocks.
        // So, we clone and avoid differentiating them.
        //
        if (auto origType = as<IRType>(origInst))
            return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr);
        
        // If we reach this statement, the instruction type is likely unhandled.
        getSink()->diagnose(origInst->sourceLoc,
                    Diagnostics::unimplemented,
                    "this instruction cannot be differentiated");

        return InstPair(nullptr, nullptr);
    }
};

struct IRWorkQueue
{
    // Work list to hold the active set of insts whose children
    // need to be looked at.
    //
    List<IRInst*> workList;
    HashSet<IRInst*> workListSet;

    void push(IRInst* inst)
    {
        if(!inst) return;
        if(workListSet.Contains(inst)) return;
        
        workList.add(inst);
        workListSet.Add(inst);
    }

    IRInst* pop()
    {
        if (workList.getCount() != 0)
        {
            IRInst* topItem = workList.getFirst();
            // TODO(Sai): Repeatedly calling removeAt() can be really slow.
            // Consider a specialized data structure or using removeLast()
            // 
            workList.removeAt(0);
            workListSet.Remove(topItem);
            return topItem;
        }
        return nullptr;
    }

    IRInst* peek()
    {
        return workList.getFirst();
    }
};

struct JVPDerivativeContext
{

    DiagnosticSink* getSink()
    {
        return sink;
    }

    bool processModule()
    {
        // We start by initializing our shared IR building state,
        // since we will re-use that state for any code we
        // generate along the way.
        //
        SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
        sharedBuilder->init(module);
    
        IRBuilder builderStorage(sharedBuilderStorage);
        IRBuilder* builder = &builderStorage;

        // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by 
        // generating derivative code for the referenced function.
        //
        bool modified = processReferencedFunctions(builder);

        // Replaces IRDifferentialPairType with an auto-generated struct,
        // IRDifferentialPairGetDifferential with 'differential' field access,
        // IRDifferentialPairGetPrimal with 'primal' field access, and
        // IRMakeDifferentialPair with an IRMakeStruct.
        // 
        modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage));

        return modified;
    }

    IRInst* lookupJVPReference(IRInst* primalFunction)
    {
        if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
            return jvpDefinition->getJVPFunc();
        
        return nullptr;
    }

    // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate),
    // then check that the referenced function is marked correctly for differentiation.
    //
    bool processReferencedFunctions(IRBuilder* builder)
    {
        IRWorkQueue* workQueue = &(workQueueStorage);

        // Put the top-level inst into the queue.
        workQueue->push(module->getModuleInst());
        
        // Keep processing items until the queue is complete.
        while (IRInst* workItem = workQueue->pop())
        {   
            for(auto child = workItem->getFirstChild(); child; child = child->getNextInst())
            {
                // Either the child instruction has more children (func/block etc..)
                // and we add it to the work list for further processing, or 
                // it's an ordinary inst in which case we check if it's a JVPDifferentiate
                // instruction.
                //
                if (child->getFirstChild() != nullptr)
                    workQueue->push(child);
                
                if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
                {
                    auto baseFunction = jvpDiffInst->getBaseFn();
                    // If the JVP Reference already exists, no need to
                    // differentiate again.
                    //
                    if(lookupJVPReference(baseFunction)) continue;

                    if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction)))
                    {
                        IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction));
                        builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction);
                        workQueue->push(jvpFunction);
                    }
                    else
                    {
                        // TODO(Sai): This would probably be better with a more specific
                        // error code.
                        getSink()->diagnose(jvpDiffInst->sourceLoc,
                            Diagnostics::internalCompilerError,
                            "Cannot differentiate functions not marked for differentiation");
                    }
                }
            }
        }

        return true;
    }

    // Run through all the global-level instructions, 
    // looking for callables.
    // Note: We're only processing global callables (IRGlobalValueWithCode)
    // for now.
    // 
    bool processMarkedGlobalFunctions(IRBuilder* builder)
    {
        for (auto inst : module->getGlobalInsts())
        {
            // If the instr is a callable, get all the basic blocks
            if (auto callable = as<IRGlobalValueWithCode>(inst))
            {
                if (isFunctionMarkedForJVP(callable))
                {   
                    SLANG_ASSERT(as<IRFunc>(callable));

                    IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable));
                    builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);

                    unmarkForJVP(callable);
                }
            }
        }
        return true;
    }

    IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext)
    {
        if (diffContext->isInterfaceAvailable)
        {
            if (auto pairType = as<IRDifferentialPairType>(type))
            {
                builder->setInsertBefore(pairType);

                auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
                    builder,
                    pairType->getValueType());

                pairType->replaceUsesWith(diffPairStructType);
                pairType->removeAndDeallocate();

                return diffPairStructType;
            }
            else if (auto loweredStructType = as<IRStructType>(type))
            {
                // Already lowered to struct.
                return loweredStructType;
            }
        }
        return nullptr;
    }

    IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
    {
        
        if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
        {
            auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext);
            
            builder->setInsertBefore(makePairInst);
            
            List<IRInst*> operands;
            operands.add(makePairInst->getPrimalValue());
            operands.add(makePairInst->getDifferentialValue());

            auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands);
            makePairInst->replaceUsesWith(makeStructInst);
            makePairInst->removeAndDeallocate();

            return makeStructInst;
        }
        
        return nullptr;
    }

    IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
    {
        
        if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
        {
            lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext);

            builder->setInsertBefore(getDiffInst);
            
            auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
            getDiffInst->replaceUsesWith(diffFieldExtract);
            getDiffInst->removeAndDeallocate();

            return diffFieldExtract;
        }
        else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
        {
            lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext);

            builder->setInsertBefore(getPrimalInst);

            auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
            getPrimalInst->replaceUsesWith(primalFieldExtract);
            getPrimalInst->removeAndDeallocate();

            return primalFieldExtract;
        }
        
        return nullptr;
    }

    bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext)
    {
        bool modified = false;

        // Create a new sub-context to scan witness tables inside workItem 
        // (mainly relevant if instWithChildren is a generic scope)
        // 
        auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren);
        (&pairBuilderStorage)->diffConformanceContext = (&subContext);

        for (auto child = instWithChildren->getFirstChild(); child; )
        {
            // Make sure the builder is at the right level.
            builder->setInsertInto(instWithChildren);

            auto nextChild = child->getNextInst();

            switch (child->getOp())
            {
                case kIROp_DifferentialPairType:
                    lowerPairType(builder, as<IRType>(child), &subContext);
                    break;
                
                case kIROp_DifferentialPairGetDifferential:
                case kIROp_DifferentialPairGetPrimal:
                    lowerPairAccess(builder, child, &subContext);
                    break;
                
                case kIROp_MakeDifferentialPair:
                    lowerMakePair(builder, child, &subContext);
                    break;
                
                default:
                    if (child->getFirstChild())
                        modified = processPairTypes(builder, child, (&subContext)) | modified;
            }

            child = nextChild;
        }

        // Reset the context back to the parent.
        (&pairBuilderStorage)->diffConformanceContext = diffContext;

        return modified;
    }

    // Checks decorators to see if the function should
    // be differentiated (kIROp_JVPDerivativeMarkerDecoration)
    // 
    bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable)
    {
        for(auto decoration = callable->getFirstDecoration(); 
            decoration;
            decoration = decoration->getNextDecoration())
        {
            if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
            {
                return true;
            }
        }
        return false;
    }

    // Removes the JVPDerivativeMarkerDecoration from the provided callable, 
    // if it exists.
    //
    void unmarkForJVP(IRGlobalValueWithCode* callable)
    {
        for(auto decoration = callable->getFirstDecoration(); 
            decoration;
            decoration = decoration->getNextDecoration())
        {
            if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
            {
                decoration->removeAndDeallocate();
                return;
            }
        }
    }

    List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
    {
        List<IRParam*> params;
        for(UIndex i = 0; i < dataType->getParamCount(); i++)
        {
            params.add(
                builder->emitParam(dataType->getParamType(i)));
        }
        return params;
    }

    // Perform forward-mode automatic differentiation on 
    // the intstructions.
    //
    IRFunc* emitJVPFunction(IRBuilder* builder,
                            IRFunc*    primalFn)
    {
        eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn);

        builder->setInsertBefore(primalFn->getNextInst()); 

        auto jvpFn = builder->createFunc();
        
        SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType()));
        IRType* jvpFuncType = transcriberStorage.differentiateFunctionType(
            builder,
            as<IRFuncType>(primalFn->getFullType()));
        jvpFn->setFullType(jvpFuncType);

        if (auto jvpName = getJVPFuncName(builder, primalFn))
            builder->addNameHintDecoration(jvpFn, jvpName);

        builder->setInsertInto(jvpFn);
        
        // Emit a block instruction for every block in the function, and map it as the 
        // corresponding differential.
        //
        for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
        {
            auto jvpBlock = builder->emitBlock();
            transcriberStorage.mapDifferentialInst(block, jvpBlock);
            transcriberStorage.mapPrimalInst(block, jvpBlock);
        }

        // Go back over the blocks, and process the children of each block.
        for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
        {
            auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block));
            SLANG_ASSERT(jvpBlock);
            emitJVPBlock(builder, block, jvpBlock);
        }

        return jvpFn;
    }

    IRStringLit* getJVPFuncName(IRBuilder*    builder,
                                IRFunc*       func)
    {
        auto oldLoc = builder->getInsertLoc();
        builder->setInsertBefore(func);
        
        IRStringLit* name = nullptr;
        if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
        {
            name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice());
        }
        else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
        {
            name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice());
        }

        builder->setInsertLoc(oldLoc);

        return name;
    }

    IRBlock* emitJVPBlock(IRBuilder*    builder, 
                          IRBlock*      origBlock,
                          IRBlock*      jvpBlock = nullptr)
    {   
        JVPTranscriber* transcriber = &(transcriberStorage);

        // Create if not already created, and then insert into new block.
        if (!jvpBlock)
            jvpBlock = builder->emitBlock();
        else
            builder->setInsertInto(jvpBlock);

        
        // First transcribe every parameter in the block.
        for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
        {
            transcriber->transcribe(builder, param);
        }

        // Then, run through every instruction and use the transcriber to generate the appropriate
        // derivative code.
        //
        for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
        {
            transcriber->transcribe(builder, child);
        }

        return jvpBlock;
    }

    JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
        module(module), sink(sink),
        diffConformanceContextStorage(module->getModuleInst()),
        pairBuilderStorage(&diffConformanceContextStorage)
    {
        transcriberStorage.sink = sink;
        transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage);
        transcriberStorage.pairBuilder = &(pairBuilderStorage);
    }

    protected:

    // This type passes over the module and generates
    // forward-mode derivative versions of functions 
    // that are explicitly marked for it.
    //
    IRModule*                       module;

    // Shared builder state for our derivative passes.
    SharedIRBuilder                 sharedBuilderStorage;

    // A transcriber object that handles the main job of 
    // processing instructions while maintaining state.
    //
    JVPTranscriber                  transcriberStorage;
    
    // Diagnostic object from the compile request for
    // error messages.
    DiagnosticSink*                 sink;

    // Work queue to hold a stream of instructions that need
    // to be checked for references to derivative functions.
    IRWorkQueue                     workQueueStorage;

    // Context to find and manage the witness tables for types 
    // implementing `IDifferentiable`
    DifferentiableTypeConformanceContext diffConformanceContextStorage;

    // Builder for dealing with differential pair types.
    DifferentialPairTypeBuilder     pairBuilderStorage;

};

// Set up context and call main process method.
//
bool processJVPDerivativeMarkers(
        IRModule*                           module,
        DiagnosticSink*                     sink,
        IRJVPDerivativePassOptions const&)
{   
    // Simplify module to remove dead code.
    IRDeadCodeEliminationOptions options;
    options.keepExportsAlive = true;
    options.keepLayoutsAlive = true;
    eliminateDeadCode(module, options);

    JVPDerivativeContext context(module, sink);

    return context.processModule();
}

}
back to top