https://github.com/shader-slang/slang
Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-autodiff-fwd.cpp
// slang-ir-autodiff-fwd.cpp
#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"

#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-single-return.h"
#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-validate.h"
#include "slang-ir-inline.h"
#include "slang-ir-init-local-var.h"

namespace Slang
{

IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
{
    SLANG_UNUSED(func);

    List<IRType*> newParameterTypes;
    IRType* diffReturnType;

    for (UIndex i = 0; i < funcType->getParamCount(); i++)
    {
        auto origType = funcType->getParamType(i);
        origType = (IRType*) findOrTranscribePrimalInst(builder, origType);
        if (auto diffPairType = tryGetDiffPairType(builder, origType))
            newParameterTypes.add(diffPairType);
        else
            newParameterTypes.add(origType);
    }

    // Transcribe return type to a pair.
    // This will be void if the primal return type is non-differentiable.
    //
    auto origResultType = (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType());
    if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
        diffReturnType = returnPairType;
    else
        diffReturnType = origResultType;

    return builder->getFuncType(newParameterTypes, diffReturnType);
}

void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc)
{
    IRBuilder builder(diffFunc);
    builder.setInsertInto(diffFunc);
    auto block = builder.emitBlock();
    builder.markInstAsMixedDifferential(block);

    for (auto param : primalFunc->getParams())
    {
        transcribeFuncParam(&builder, param, param->getFullType());
    }
    List<IRParam*> diffParams;
    for (auto param : diffFunc->getParams())
    {
        diffParams.add(param);
    }
    auto emitDiffPairVal = [&](IRDifferentialPairTypeBase* pairType)
    {
        auto primal = builder.emitDefaultConstruct(pairType->getValueType());
        builder.markInstAsPrimal(primal);
        auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType());
        builder.markInstAsDifferential(diff, primal->getDataType());

        auto val = builder.emitMakeDifferentialPair(pairType, primal, diff);
        builder.markInstAsMixedDifferential(val);

        return val;
        
    };
    for (auto param : diffParams)
    {
        if (auto outType = as<IROutTypeBase>(param->getFullType()))
        {
            if (isRelevantDifferentialPair(outType))
            {
                auto pairType = as<IRDifferentialPairTypeBase>(outType->getValueType());
                auto val = emitDiffPairVal(pairType);
                auto store = builder.emitStore(param, val);
                builder.markInstAsMixedDifferential(store);
            }
            else
            {
                auto val = builder.emitDefaultConstruct(outType->getValueType());
                builder.markInstAsPrimal(val);

                auto store = builder.emitStore(param, val);
                builder.markInstAsPrimal(store);

            }
        }
    }
    if (isRelevantDifferentialPair(diffFunc->getResultType()))
    {
        auto pairType = as<IRDifferentialPairTypeBase>(diffFunc->getResultType());
        auto val = emitDiffPairVal(pairType);
        auto returnInst = builder.emitReturn(val);
        builder.markInstAsMixedDifferential(val);
        builder.markInstAsMixedDifferential(returnInst);
    }
    else
    {
        auto retVal = builder.emitDefaultConstruct(diffFunc->getResultType());
        auto returnInst = builder.emitReturn(retVal);
        builder.markInstAsPrimal(retVal);
        builder.markInstAsPrimal(returnInst);
    }
}

// Returns "d<var-name>" to use as a name hint for variables and parameters.
// If no primal name is available, returns a blank string.
// 
String ForwardDiffTranscriber::getJVPVarName(IRInst* origVar)
{
    if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
    {
        return ("d" + String(namehintDecoration->getName()));
    }

    return String("");
}

InstPair ForwardDiffTranscriber::transcribeUndefined(IRBuilder* builder, IRInst* origInst)
{
    auto primalVal = maybeCloneForPrimalInst(builder, origInst);

    if (IRType* const diffType = differentiateType(builder, origInst->getFullType()))
    {
        auto dzero = getDifferentialZeroOfType(builder, origInst->getFullType());
        if (dzero)
        {
            return InstPair(primalVal, dzero);
        }
    }
    return InstPair(primalVal, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRInst* origInst)
{
    auto primalVal = maybeCloneForPrimalInst(builder, origInst);

    IRInst* diffVal = nullptr;

    if (IRType* const diffType = differentiateType(builder, origInst->getFullType()))
    {
        if (auto diffOperand = findOrTranscribeDiffInst(builder, origInst->getOperand(0)))
        {
            diffVal = builder->emitReinterpret(diffType, diffOperand);
        }
    }

    return InstPair(primalVal, diffVal);
}

InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
    if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
    {
        IRVar* diffVar = builder->emitVar(diffType);
        SLANG_ASSERT(diffVar);

        auto diffNameHint = getJVPVarName(origVar);
        if (diffNameHint.getLength() > 0)
            builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());

        return InstPair(maybeCloneForPrimalInst(builder, origVar), diffVar);
    }
    return InstPair(maybeCloneForPrimalInst(builder, origVar), nullptr);
}

InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
{
    SLANG_ASSERT(origArith->getOperandCount() == 2);

    IRInst* primalArith = maybeCloneForPrimalInst(builder, origArith);
    
    auto origLeft = origArith->getOperand(0);
    auto origRight = origArith->getOperand(1);

    auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
    auto primalRight = findOrTranscribePrimalInst(builder, origRight);

    auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
    auto diffRight = findOrTranscribeDiffInst(builder, origRight);


    if (diffLeft || diffRight)
    {
        diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());

        bool diffRightIsZero = (diffRight == nullptr);
        diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
        diffRightIsZero = diffRightIsZero || isZero(diffRight);

        auto resultType = primalArith->getDataType();
        auto origResultType = origArith->getDataType();
        auto diffType = (IRType*)differentiateType(builder, origResultType);

        switch(origArith->getOp())
        {
        case kIROp_Add:
            {
                auto diffAdd = builder->emitAdd(diffType, diffLeft, diffRight);
                builder->markInstAsDifferential(diffAdd, resultType);

                return InstPair(primalArith, diffAdd);
            }

        case kIROp_Mul:
            {
                auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
                auto diffRightTimesLeft = builder->emitMul(diffType, diffRight, primalLeft);
                builder->markInstAsDifferential(diffLeftTimesRight, resultType);
                builder->markInstAsDifferential(diffRightTimesLeft, resultType);

                auto diffAdd = builder->emitAdd(diffType, diffLeftTimesRight, diffRightTimesLeft);
                builder->markInstAsDifferential(diffAdd, resultType);

                return InstPair(primalArith, diffAdd);
            }

        case kIROp_Sub:
            {
                auto diffSub = builder->emitSub(diffType, diffLeft, diffRight);
                builder->markInstAsDifferential(diffSub, resultType);

                return InstPair(primalArith, diffSub);
            }
        case kIROp_Div:
            {
                if (diffRightIsZero)
                {
                    // Special case the dRight = 0 case here since it would be difficult
                    // to optimize out in the future.
                    IRInst* diff = nullptr;
                    if (auto constant = as<IRFloatLit>(primalRight))
                    {
                        diff = builder->emitMul(
                            diffType,
                            diffLeft,
                            builder->getFloatValue(
                                constant->getDataType(), 1.0 / constant->getValue()));
                        builder->markInstAsDifferential(diff, resultType);
                    }
                    else
                    {
                        diff = builder->emitDiv(diffType, diffLeft, primalRight);
                        builder->markInstAsDifferential(diff, resultType);
                    }
                    return InstPair(primalArith, diff);
                }
                else
                {
                    auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight);
                    builder->markInstAsDifferential(diffLeftTimesRight, resultType);

                    auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight);
                    builder->markInstAsDifferential(diffRightTimesLeft, resultType);
                
                    auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft);
                    builder->markInstAsDifferential(diffSub, resultType);

                    auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight);
                    builder->markInstAsPrimal(diffMul);

                    auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
                    builder->markInstAsDifferential(diffDiv, resultType);

                    return InstPair(primalArith, diffDiv);
                }
            }
        default:
            getSink()->diagnose(origArith->sourceLoc,
                Diagnostics::unimplemented,
                "this arithmetic instruction cannot be differentiated");
        }
    }

    return InstPair(primalArith, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
{
    SLANG_ASSERT(origLogic->getOperandCount() == 2);

    // Boolean operations are not differentiable. For the linearization
    // pass, we do not need to do anything but copy them over to the ne
    // function.
    auto primalLogic = maybeCloneForPrimalInst(builder, origLogic);
    return InstPair(primalLogic, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeSelect(IRBuilder* builder, IRInst* origSelect)
{
    auto primalCondition = lookupPrimalInst(builder, origSelect->getOperand(0));

    auto origLeft = origSelect->getOperand(1);
    auto origRight = origSelect->getOperand(2);

    auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
    auto primalRight = findOrTranscribePrimalInst(builder, origRight);

    auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
    auto diffRight = findOrTranscribeDiffInst(builder, origRight);

    auto primalSelect = maybeCloneForPrimalInst(builder, origSelect);

    // If both sides have no differential, skip
    if (diffLeft || diffRight)
    {
        diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
        diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());

        auto diffType = differentiateType(builder, origSelect->getDataType());

        return InstPair(
            primalSelect,
            builder->emitIntrinsicInst(
                diffType,
                kIROp_Select,
                3,
                List<IRInst*>(primalCondition, diffLeft, diffRight).getBuffer()));
    }
    
    return InstPair(primalSelect, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
    auto origPtr = origLoad->getPtr();
    auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr);
    auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType());
    if (primalPtrType)
    {
        if (auto diffPairType = as<IRDifferentialPairType>(primalPtrType->getValueType()))
        {
            // Special case load from an `out` param, which will not have corresponding `diff` and
            // `primal` insts yet.

            // TODO: Could we move this load to _after_ DifferentialPairGetPrimal, 
            // and DifferentialPairGetDifferential?
            // 
            auto load = builder->emitLoad(primalPtr);
            builder->markInstAsMixedDifferential(load, diffPairType);

            auto primalElement = builder->emitDifferentialPairGetPrimal(load);
            auto diffElement = builder->emitDifferentialPairGetDifferential(
                (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load);
            return InstPair(primalElement, diffElement);
        }
    }

    auto primalLoad = maybeCloneForPrimalInst(builder, origLoad);
    IRInst* diffLoad = nullptr;
    if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
    {
        // Default case, we're loading from a known differential inst.
        diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
    } 
    return InstPair(primalLoad, diffLoad);
}

InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore)
{
    IRInst* origStoreLocation = origStore->getPtr();
    IRInst* origStoreVal = origStore->getVal();
    auto primalStoreLocation = lookupPrimalInst(builder, origStoreLocation, nullptr);
    auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
    auto primalStoreVal = lookupPrimalInst(builder, origStoreVal, nullptr);
    auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);

    if (!diffStoreLocation)
    {
        auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType());
        if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
        {
            auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
            builder->markInstAsMixedDifferential(diffStoreVal, diffPairType);

            auto store = builder->emitStore(primalStoreLocation, valToStore);
            builder->markInstAsMixedDifferential(store, diffPairType);

            return InstPair(store, nullptr);
        }
    }

    auto primalStore = maybeCloneForPrimalInst(builder, origStore);

    IRInst* diffStore = nullptr;

    // If the stored value has a differential version, 
    // emit a store instruction for the differential parameter.
    // Otherwise, emit nothing since there's nothing to load.
    // 
    if (diffStoreLocation && diffStoreVal)
    {
        // Default case, storing the entire type (and not a member)
        diffStore = as<IRStore>(
            builder->emitStore(diffStoreLocation, diffStoreVal));
        
        return InstPair(primalStore, diffStore);
    }

    return InstPair(primalStore, nullptr);
}

// Since int/float literals are sometimes nested inside an IRConstructor
// instruction, we check to make sure that the nested instr is a constant
// and then return nullptr. Literals do not need to be differentiated.
//
InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
{   
    IRInst* primalConstruct = maybeCloneForPrimalInst(builder, origConstruct);
    
    // Check if the output type can be differentiated. If it cannot be 
    // differentiated, don't differentiate the inst
    // 
    auto primalConstructType = (IRType*)findOrTranscribePrimalInst(builder, origConstruct->getDataType());
    // TODO: Need to update this to generate derivatives on a per-key basis
    if (auto diffConstructType = differentiateType(builder, primalConstructType))
    {
        UCount operandCount = origConstruct->getOperandCount();

        List<IRInst*> diffOperands;
        for (UIndex ii = 0; ii < operandCount; ii++)
        {
            // If the operand has a differential version, replace the original with
            // the differential. Otherwise, use a zero.
            // 
            if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr))
                diffOperands.add(diffInst);
            else 
            {
                auto operandDataType = origConstruct->getOperand(ii)->getDataType();
                if (const auto diffOperandType = differentiateType(builder, operandDataType))
                {
                    operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
                    diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
                }
            }
        }
        
        return InstPair(
            primalConstruct, 
            builder->emitIntrinsicInst(
                diffConstructType,
                origConstruct->getOp(),
                diffOperands.getCount(),
                diffOperands.getBuffer()));
    }
    else
    {
        return InstPair(primalConstruct, nullptr);
    }
}

InstPair ForwardDiffTranscriber::transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct)
{
    IRInst* primalMakeStruct = maybeCloneForPrimalInst(builder, origMakeStruct);

    // Check if the output type can be differentiated. If it cannot be 
    // differentiated, don't differentiate the inst
    // 
    auto primalStructType = (IRType*)findOrTranscribePrimalInst(builder, origMakeStruct->getDataType());
    if (auto diffStructType = differentiateType(builder, primalStructType))
    {
        auto primalStruct = as<IRStructType>(getResolvedInstForDecorations(primalStructType));
        SLANG_RELEASE_ASSERT(primalStruct);

        List<IRInst*> diffOperands;
        UIndex ii = 0;
        for (auto field : primalStruct->getFields())
        {
            SLANG_RELEASE_ASSERT(ii < origMakeStruct->getOperandCount());

            // If this field is not differentiable, skip the operand.
            if (!field->getKey()->findDecoration<IRDerivativeMemberDecoration>())
            {
                ii++;
                continue;
            }

            // If the operand has a differential version, replace the original with
            // the differential. Otherwise, use a zero.
            // 
            if (auto diffInst = lookupDiffInst(origMakeStruct->getOperand(ii), nullptr))
            {
                diffOperands.add(diffInst);
            }
            else
            {
                auto operandDataType = origMakeStruct->getOperand(ii)->getDataType();
                auto diffOperandType = differentiateType(builder, operandDataType);
                SLANG_RELEASE_ASSERT(diffOperandType);
                operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
                diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
            }
            ii++;
        }

        return InstPair(
            primalMakeStruct,
            builder->emitIntrinsicInst(
                diffStructType,
                kIROp_MakeStruct,
                diffOperands.getCount(),
                diffOperands.getBuffer()));
    }
    else
    {
        return InstPair(primalMakeStruct, nullptr);
    }
}

static bool _isDifferentiableFunc(IRInst* func)
{
    func = getResolvedInstForDecorations(func);
    for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration())
    {
        switch (decor->getOp())
        {
        case kIROp_ForwardDerivativeDecoration:
        case kIROp_ForwardDifferentiableDecoration:
        case kIROp_BackwardDerivativeDecoration:
        case kIROp_BackwardDifferentiableDecoration:
        case kIROp_UserDefinedBackwardDerivativeDecoration:
            return true;
        }
    }
    return false;
}

static IRFuncType* _getCalleeActualFuncType(IRInst* callee)
{
    auto type = callee->getFullType();
    if (auto funcType = as<IRFuncType>(type))
        return funcType;
    if (auto specialize = as<IRSpecialize>(callee))
        return as<IRFuncType>(findGenericReturnVal(as<IRGeneric>(specialize->getBase()))->getFullType());
    return nullptr;
}

IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee)
{
    if (auto func = as<IRFunc>(callee))
    {
        if (auto decor = func->findDecoration<IRPrimalSubstituteDecoration>())
            return decor->getPrimalSubstituteFunc();
    }
    else if (auto specialize = as<IRSpecialize>(callee))
    {
        auto innerGen = as<IRGeneric>(specialize->getBase());
        if (!innerGen)
            return callee;
        auto innerFunc = findGenericReturnVal(innerGen);
        if (auto decor = innerFunc->findDecoration<IRPrimalSubstituteDecoration>())
        {
            auto substSpecialize = as<IRSpecialize>(decor->getPrimalSubstituteFunc());
            SLANG_RELEASE_ASSERT(substSpecialize);
            SLANG_RELEASE_ASSERT(substSpecialize->getArgCount() == specialize->getArgCount());
            List<IRInst*> args;
            for (UInt i = 0; i < specialize->getArgCount(); i++)
                args.add(specialize->getArg(i));
            return builder->emitSpecializeInst(
                callee->getFullType(),
                substSpecialize->getBase(),
                (UInt)args.getCount(),
                args.getBuffer());
        }
    }
    return callee;
}

// Differentiating a call instruction here is primarily about generating
// an appropriate call list based on whichever parameters have differentials 
// in the current transcription context.
// 
InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall)
{   
    
    IRInst* origCallee = origCall->getCallee();

    if (!origCallee)
    {
        // Note that this can only happen if the callee is a result
        // of a higher-order operation. For now, we assume that we cannot
        // differentiate such calls safely.
        // TODO(sai): Should probably get checked in the front-end.
        //
        getSink()->diagnose(origCall->sourceLoc,
            Diagnostics::internalCompilerError,
            "attempting to differentiate unresolved callee");
        
        return InstPair(nullptr, nullptr);
    }

    auto primalCallee = findOrTranscribePrimalInst(builder, origCallee);
    auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee);

    IRInst* diffCallee = nullptr;
    if (substPrimalCallee == primalCallee)
    {
        instMapD.tryGetValue(origCallee, diffCallee);
    }
    else
    {
        if (_isDifferentiableFunc(origCallee))
            diffCallee = findOrTranscribeDiffInst(builder, origCallee);
        primalCallee = substPrimalCallee;
    }

    if (diffCallee)
    {
    }
    else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
    {
        // If the user has already provided an differentiated implementation, use that.
        diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
    }
    else if (_isDifferentiableFunc(primalCallee))
    {
        // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
        // to generate the implementation.
        diffCallee = builder->emitForwardDifferentiateInst(
            differentiateFunctionType(
                builder, primalCallee, as<IRFuncType>(primalCallee->getFullType())),
            primalCallee);
    }

    if (!diffCallee)
    {
        // The callee is non differentiable, just return primal value with null diff value.
        IRInst* primalCall = maybeCloneForPrimalInst(builder, origCall);
        return InstPair(primalCall, nullptr);
    }

    auto calleeType = _getCalleeActualFuncType(primalCallee);
    SLANG_ASSERT(calleeType);
    SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount());

    auto diffCalleeType = _getCalleeActualFuncType(diffCallee);
    SLANG_ASSERT(diffCalleeType);
    SLANG_RELEASE_ASSERT(diffCalleeType->getParamCount() == origCall->getArgCount());

    auto placeholderCall = builder->emitCallInst(nullptr, builder->emitUndefined(builder->getTypeKind()), 0, nullptr);
    builder->setInsertBefore(placeholderCall);
    IRBuilder argBuilder = *builder;
    IRBuilder afterBuilder = argBuilder;
    afterBuilder.setInsertAfter(placeholderCall);

    List<IRInst*> args;
    // Go over the parameter list and create pairs for each input (if required)
    for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
    {
        auto origArg = origCall->getArg(ii);
        auto primalArg = findOrTranscribePrimalInst(&argBuilder, origArg);
        SLANG_ASSERT(primalArg);

        auto origType = origCall->getArg(ii)->getDataType();
        auto primalType = primalArg->getDataType();
        auto originalParamType = calleeType->getParamType(ii);
        auto diffParamType = diffCalleeType->getParamType(ii);
        if (!isNoDiffType(originalParamType))
        {
            if (isNoDiffType(primalType))
            {
                while (auto attrType = as<IRAttributedType>(primalType))
                    primalType = attrType->getBaseType();
                while (auto attrType = as<IRAttributedType>(origType))
                    origType = attrType->getBaseType();
            }
            if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
            {
                auto pairPtrType = as<IRPtrTypeBase>(pairType);
                auto pairValType = as<IRDifferentialPairType>(
                    pairPtrType ? pairPtrType->getValueType() : pairType);
                auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType);
                if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
                {
                    // Create temp var to pass in/out arguments.
                    auto srcVar = argBuilder.emitVar(pairValType);
                    argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType());

                    auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
                    if (ptrParamType->getOp() == kIROp_InOutType)
                    {
                        // Set initial value.
                        auto primalVal = argBuilder.emitLoad(primalArg);
                        auto diffArgVal = diffArg;
                        if (!diffArg)
                            diffArgVal = getDifferentialZeroOfType(builder, (IRType*)pairValType->getValueType());
                        else
                        {
                            diffArgVal = argBuilder.emitLoad(diffArg);
                            argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType());
                        }
                        auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal);
                        argBuilder.markInstAsMixedDifferential(initVal, primalType);
                        auto store = argBuilder.emitStore(srcVar, initVal);
                        argBuilder.markInstAsMixedDifferential(store, primalType);
                    }
                    if (as<IROutTypeBase>(ptrParamType))
                    {
                        // Read back new value.
                        auto newVal = afterBuilder.emitLoad(srcVar);
                        afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType());
                        auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal);
                        afterBuilder.emitStore(primalArg, newPrimalVal);

                        if (diffArg)
                        {
                            auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal);
                            afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType());

                            auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal);
                            afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType());
                        }
                    }
                    args.add(srcVar);
                    continue;
                }
                else
                {
                    auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
                    if (!diffArg)
                        diffArg = getDifferentialZeroOfType(&argBuilder, primalType);

                    // If a pair type can be formed, this must be non-null.
                    SLANG_RELEASE_ASSERT(diffArg);

                    auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg);
                    argBuilder.markInstAsMixedDifferential(diffPair, pairType);

                    args.add(diffPair);
                    continue;
                }
                
            }
        }
        // Argument is not differentiable.
        // Add original/primal argument.
        args.add(primalArg);
    }
    
    IRType* diffReturnType = nullptr;
    diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType());

    if (!diffReturnType)
    {
        diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType());
    }

    auto callInst = argBuilder.emitCallInst(
        diffReturnType,
        diffCallee,
        args);
    placeholderCall->removeAndDeallocate();
    argBuilder.markInstAsMixedDifferential(callInst, diffReturnType);
    argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee);

    *builder = afterBuilder;

    if (diffReturnType->getOp() == kIROp_DifferentialPairType)
    {
        IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst);
        auto diffType = differentiateType(&afterBuilder, origCall->getFullType());
        IRInst* diffResultValue = afterBuilder.emitDifferentialPairGetDifferential(diffType, callInst);
        return InstPair(primalResultValue, diffResultValue);
    }
    else
    {
        // Return the inst itself if the return value is non-differentiable.
        // This is fine since these values should only be used by non-differentiable code.
        // 
        return InstPair(callInst, callInst);
    }
}

InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
{
    IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle);

    if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
    {
        List<IRInst*> swizzleIndices;
        for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
            swizzleIndices.add(origSwizzle->getElementIndex(ii));
        
        return InstPair(
            primalSwizzle,
            builder->emitSwizzle(
                differentiateType(builder, primalSwizzle->getDataType()),
                diffBase,
                origSwizzle->getElementCount(),
                swizzleIndices.getBuffer()));
    }

    return InstPair(primalSwizzle, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
{
    IRInst* primalInst = maybeCloneForPrimalInst(builder, origInst);

    UCount operandCount = origInst->getOperandCount();

    List<IRInst*> diffOperands;
    for (UIndex ii = 0; ii < operandCount; ii++)
    {
        // If the operand has a differential version, replace the original with the 
        // differential.
        // Otherwise, abandon the differentiation attempt and assume that origInst 
        // cannot (or does not need to) be differentiated.
        // 
        if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr))
            diffOperands.add(diffInst);
        else
            return InstPair(primalInst, nullptr);
    }
    
    return InstPair(
        primalInst, 
        builder->emitIntrinsicInst(
            differentiateType(builder, origInst->getDataType()),
            origInst->getOp(),
            operandCount,
            diffOperands.getBuffer()));
}

InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
{
    switch(origInst->getOp())
    {
        case kIROp_unconditionalBranch:
        case kIROp_loop:
            auto origBranch = as<IRUnconditionalBranch>(origInst);
            auto targetBlock = origBranch->getTargetBlock();

            // Grab the differentials for any phi nodes.
            List<IRInst*> newArgs;
            for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
            {
                auto origParam = getParamAt(targetBlock, ii);
                auto origArg = origBranch->getArg(ii);
                auto primalArg = lookupPrimalInst(builder, origArg);
                newArgs.add(primalArg);

                if (differentiateType(builder, origParam->getDataType()))
                {
                    auto diffArg = lookupDiffInst(origArg, nullptr);
                    if (diffArg)
                        newArgs.add(diffArg);
                    else
                        newArgs.add(
                            getDifferentialZeroOfType(builder, origArg->getDataType()));
                }
            }

            IRInst* diffBranch = nullptr;
            if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
            {
                if (auto origLoop = as<IRLoop>(origInst))
                {
                    auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
                    auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
                    List<IRInst*> operands;
                    operands.add(diffBlock);
                    operands.add(breakBlock);
                    operands.add(continueBlock);
                    operands.addRange(newArgs);
                    diffBranch = builder->emitIntrinsicInst(
                        nullptr,
                        kIROp_loop,
                        operands.getCount(),
                        operands.getBuffer());
                    if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>())
                        builder->addLoopMaxItersDecoration(diffBranch, maxItersDecoration->getMaxIters());
                }
                else
                {
                    diffBranch = builder->emitBranch(
                        as<IRBlock>(diffBlock),
                        newArgs.getCount(),
                        newArgs.getBuffer());
                }
            }

            // For now, every block in the original fn must have a corresponding
            // block to compute *both* primals and derivatives (i.e linearized block)
            SLANG_ASSERT(diffBranch);

            // Since blocks always compute both primals and differentials, the branch
            // instructions are also always mixed.
            // 
            builder->markInstAsMixedDifferential(diffBranch);

            return InstPair(diffBranch, diffBranch);
    }

    getSink()->diagnose(
        origInst->sourceLoc,
        Diagnostics::unimplemented,
        "attempting to differentiate unhandled control flow");

    return InstPair(nullptr, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst)
{
    switch(origInst->getOp())
    {
        case kIROp_FloatLit:
        case kIROp_IntLit:
            return InstPair(origInst, nullptr);
        case kIROp_VoidLit:
            return InstPair(origInst, origInst);
    }

    getSink()->diagnose(
        origInst->sourceLoc,
        Diagnostics::unimplemented,
        "attempting to differentiate unhandled const type");

    return InstPair(nullptr, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
{
    auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
    List<IRInst*> primalArgs;
    for (UInt i = 0; i < origSpecialize->getArgCount(); i++)
    {
        primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i)));
    }
    auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType());
    auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
        (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer());

    IRInst* diffBase = nullptr;
    if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase))
    {
        if (diffBase)
        {
            List<IRInst*> args;
            for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
            {
                args.add(primalSpecialize->getArg(i));
            }
            auto diffSpecialize = builder->emitSpecializeInst(
                builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
            return InstPair(primalSpecialize, diffSpecialize);
        }
        else
        {
            return InstPair(primalSpecialize, nullptr);
        }
    }

    auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));

    // Right now we don't support transcribing a differentiable callee that is a specialize of a interface lookup
    // (calling differentiable generic interface method). To support it, we need to recursively transcribe the
    // specialization base here.

    if (!genericInnerVal)
        return InstPair(primalSpecialize, nullptr);

    // Look for an IRForwardDerivativeDecoration on the specialize inst.
    // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
    // can be specialized, so we put a decoration on the IRSpecialize)
    //
    if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
    {
        auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();

        // Make sure this isn't itself a specialize .
        SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));

        auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>();
        SLANG_RELEASE_ASSERT(derivativeDecoration);

        return InstPair(primalSpecialize, jvpFunc);
    }
    else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>())
    {
        diffBase = derivativeDecoration->getForwardDerivativeFunc();
        List<IRInst*> args;
        for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
        {
            args.add(primalSpecialize->getArg(i));
        }

        // A `ForwardDerivative` decoration on an inner func of a generic should always be a `specialize`.
        auto diffBaseSpecialize = as<IRSpecialize>(diffBase);
        SLANG_RELEASE_ASSERT(diffBaseSpecialize);

        // Note: this assumes that the generic arguments to specialize the derivative is the same as the
        // generic args to specialize the primal function. This is true for all of our stdlib functions,
        // but we may need to rely on more general substitution logic here.
        auto diffSpecialize = builder->emitSpecializeInst(
            builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer());
        return InstPair(primalSpecialize, diffSpecialize);
    }
    else if (_isDifferentiableFunc(genericInnerVal) || as<IRFuncType>(genericInnerVal))
    {
        List<IRInst*> args;
        for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
        {
            args.add(primalSpecialize->getArg(i));
        }
        diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
        auto diffSpecialize = builder->emitSpecializeInst(
            builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
        return InstPair(primalSpecialize, diffSpecialize);
    }
    return InstPair(primalSpecialize, nullptr);
}

InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
{
    SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst));

    IRInst* origBase = originalInst->getOperand(0);
    auto primalBase = findOrTranscribePrimalInst(builder, origBase);
    auto field = originalInst->getOperand(1);
    auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());

    IRInst* primalOperands[] = { primalBase, field };
    IRInst* primalFieldExtract = builder->emitIntrinsicInst(
        primalType,
        originalInst->getOp(),
        2,
        primalOperands);

    if (!derivativeRefDecor)
    {
        return InstPair(primalFieldExtract, nullptr);
    }

    IRInst* diffFieldExtract = nullptr;

    if (auto diffType = differentiateType(builder, originalInst->getDataType()))
    {
        if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
        {
            IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() };
            diffFieldExtract = builder->emitIntrinsicInst(
                diffType,
                originalInst->getOp(),
                2,
                diffOperands);
        }
    }
    return InstPair(primalFieldExtract, diffFieldExtract);
}

InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
{
    SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));

    IRInst* origBase = origGetElementPtr->getOperand(0);
    auto primalBase = findOrTranscribePrimalInst(builder, origBase);
    auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));

    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origGetElementPtr->getDataType());

    IRInst* primalOperands[] = {primalBase, primalIndex};
    IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
        primalType,
        origGetElementPtr->getOp(),
        2,
        primalOperands);

    IRInst* diffGetElementPtr = nullptr;

    if (auto diffType = differentiateType(builder, origGetElementPtr->getDataType()))
    {
        if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
        {
            IRInst* diffOperands[] = {diffBase, primalIndex};
            diffGetElementPtr = builder->emitIntrinsicInst(
                diffType,
                origGetElementPtr->getOp(),
                2,
                diffOperands);
        }
    }

    return InstPair(primalGetElementPtr, diffGetElementPtr);
}

InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst)
{
    auto updateInst = as<IRUpdateElement>(originalInst);

    IRInst* origBase = updateInst->getOldValue();
    auto primalBase = findOrTranscribePrimalInst(builder, origBase);
    List<IRInst*> primalAccessChain;
    for (UInt i = 0; i < updateInst->getAccessKeyCount(); i++)
    {
        auto originalKey = updateInst->getAccessKey(i);
        auto primalKey = findOrTranscribePrimalInst(builder, originalKey);
        primalAccessChain.add(primalKey);
    }
    auto origVal = updateInst->getElementValue();
    auto primalVal = findOrTranscribePrimalInst(builder, origVal);

    IRInst* primalUpdateField =
        builder->emitUpdateElement(primalBase, primalAccessChain, primalVal);

    IRInst* diffUpdateElement = nullptr;
    List<IRInst*> diffAccessChain;
    for (auto key : primalAccessChain)
    {
        if (as<IRStructKey>(key))
        {
            auto decor = key->findDecoration<IRDerivativeMemberDecoration>();
            if (decor)
                diffAccessChain.add(decor->getDerivativeMemberStructKey());
            else
            {
                auto diffBase = findOrTranscribeDiffInst(builder, origBase);
                return InstPair(primalUpdateField, diffBase);
            }
        }
        else
        {
            diffAccessChain.add(key);
        }
    }
    if (const auto diffType = differentiateType(builder, originalInst->getDataType()))
    {
        auto diffBase = findOrTranscribeDiffInst(builder, origBase);
        if (!diffBase)
        {
            diffBase = getDifferentialZeroOfType(builder, origBase->getDataType());
        }
        if (auto diffVal = findOrTranscribeDiffInst(builder, origVal))
        {
            auto primalElementType = primalVal->getDataType();

            diffUpdateElement = builder->emitUpdateElement(
                diffBase, diffAccessChain, diffVal);
            builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
        }
        else
        {
            auto primalElementType = primalVal->getDataType();
            auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType);
            diffUpdateElement = builder->emitUpdateElement(
                diffBase, diffAccessChain, zeroElementDiff);
            builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
        }
    }
    return InstPair(primalUpdateField, diffUpdateElement);
}

InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch)
{
    // Transcribe condition (primal only, conditions do not produce differentials)
    auto primalCondition = findOrTranscribePrimalInst(builder, origSwitch->getCondition());
    SLANG_ASSERT(primalCondition);
    
    // Transcribe 'default' block
    IRBlock* diffDefaultBlock = as<IRBlock>(
        findOrTranscribeDiffInst(builder, origSwitch->getDefaultLabel()));
    SLANG_ASSERT(diffDefaultBlock);

    // Transcribe 'default' block
    IRBlock* diffBreakBlock = as<IRBlock>(
        findOrTranscribeDiffInst(builder, origSwitch->getBreakLabel()));
    SLANG_ASSERT(diffBreakBlock);

    // Transcribe all other operands
    List<IRInst*> diffCaseValuesAndLabels;
    for (UIndex ii = 0; ii < origSwitch->getCaseCount(); ii ++)
    {
        auto primalCaseValue = findOrTranscribePrimalInst(builder, origSwitch->getCaseValue(ii));
        SLANG_ASSERT(primalCaseValue);

        auto diffCaseBlock = findOrTranscribeDiffInst(builder, origSwitch->getCaseLabel(ii));
        SLANG_ASSERT(diffCaseBlock);

        diffCaseValuesAndLabels.add(primalCaseValue);
        diffCaseValuesAndLabels.add(diffCaseBlock);
    }

    auto diffSwitchInst = builder->emitSwitch(
        primalCondition,
        diffBreakBlock,
        diffDefaultBlock, 
        diffCaseValuesAndLabels.getCount(),
        diffCaseValuesAndLabels.getBuffer());
    builder->markInstAsMixedDifferential(diffSwitchInst);

    return InstPair(diffSwitchInst, diffSwitchInst);
}

InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
{
    // IfElse Statements come with 4 blocks. We transcribe each block into it's
    // linear form, and then wire them up in the same way as the original if-else
    
    // Transcribe condition block
    auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
    SLANG_ASSERT(primalConditionBlock);

    // Transcribe 'true' block (condition block branches into this if true)
    auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
    SLANG_ASSERT(diffTrueBlock);

    // Transcribe 'false' block (condition block branches into this if true)
    auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
    SLANG_ASSERT(diffFalseBlock);

    // Transcribe 'after' block (true and false blocks branch into this)
    auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
    SLANG_ASSERT(diffAfterBlock);
    
    List<IRInst*> diffIfElseArgs;
    diffIfElseArgs.add(primalConditionBlock);
    diffIfElseArgs.add(diffTrueBlock);
    diffIfElseArgs.add(diffFalseBlock);
    diffIfElseArgs.add(diffAfterBlock);

    // If there are any other operands, use their primal versions.
    for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++)
    {
        auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii));
        diffIfElseArgs.add(primalOperand);
    }

    IRInst* diffIfElse = builder->emitIntrinsicInst(
        nullptr,
        kIROp_ifElse,
        diffIfElseArgs.getCount(),
        diffIfElseArgs.getBuffer());
    builder->markInstAsMixedDifferential(diffIfElse);

    return InstPair(diffIfElse, diffIfElse);
}

InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst)
{
    auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
    SLANG_ASSERT(primalVal);
    auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
    SLANG_ASSERT(diffPrimalVal);

    auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
    if (!primalDiffVal)
        primalDiffVal = getDifferentialZeroOfType(builder, origInst->getPrimalValue()->getDataType());
    SLANG_ASSERT(primalDiffVal);

    auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
    if (!diffDiffVal)
        diffDiffVal = getDifferentialZeroOfType(builder, origInst->getDifferentialValue()->getDataType());
    SLANG_ASSERT(diffDiffVal);

    auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType());
    auto diffPairType = findOrTranscribeDiffInst(builder, origInst->getFullType());
    auto primalPair = builder->emitMakeDifferentialPairUserCode(
        (IRType*)primalPairType, primalVal, diffPrimalVal);
    auto diffPair = builder->emitMakeDifferentialPairUserCode(
        (IRType*)diffPairType,
        primalDiffVal,
        diffDiffVal);
    return InstPair(primalPair, diffPair);
}

InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
{
    SLANG_ASSERT(
        origInst->getOp() == kIROp_DifferentialPairGetDifferentialUserCode ||
        origInst->getOp() == kIROp_DifferentialPairGetPrimalUserCode);

    auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
    SLANG_ASSERT(primalVal);

    auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
    SLANG_ASSERT(diffVal);

    auto primalType = findOrTranscribePrimalInst(builder, origInst->getFullType());

    auto primalResult = builder->emitIntrinsicInst((IRType*)primalType, origInst->getOp(), 1, &primalVal);

    auto diffValPairType = as<IRDifferentialPairUserCodeType>(diffVal->getDataType());
    IRInst* diffResultType = nullptr;
    if (origInst->getOp() == kIROp_DifferentialPairGetDifferentialUserCode)
        diffResultType = differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffValPairType);
    else
        diffResultType = diffValPairType->getValueType();
    auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
    return InstPair(primalResult, diffResult);
}

InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst)
{
    IRInst* origBase = origInst->getOperand(0);
    auto primalBase = findOrTranscribePrimalInst(builder, origBase);
    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType());

    IRInst* primalResult = builder->emitIntrinsicInst(
        primalType,
        origInst->getOp(),
        1,
        &primalBase);

    IRInst* diffResult = nullptr;

    if (auto diffType = differentiateType(builder, origInst->getDataType()))
    {
        if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
        {
            diffResult = builder->emitIntrinsicInst(
                diffType,
                origInst->getOp(),
                1,
                &diffBase);
        }
    }
    return InstPair(primalResult, diffResult);
}

InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential)
{
    auto origBase = origMakeExistential->getWrappedValue();
    auto origWitnessTable = origMakeExistential->getWitnessTable();

    auto primalBase = findOrTranscribePrimalInst(builder, origBase);
    auto primalWitnessTable = findOrTranscribePrimalInst(builder, origWitnessTable);
    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origMakeExistential->getDataType());

    IRInst* primalResult = builder->emitMakeExistential(
        primalType,
        primalBase,
        primalWitnessTable);

    IRInst* diffResult = nullptr;

    auto primalInterfaceType = as<IRInterfaceType>(unwrapAttributedType(origMakeExistential->getDataType()));
    SLANG_RELEASE_ASSERT(primalInterfaceType);

    // If the interface type of the existential is differentiable, we emit a make existential
    // of IDifferentiable.Differential type and the witness table of the original type's conformance
    // to IDifferentiable.
    //
    if (auto differentialWitnessTable = differentiableTypeConformanceContext.tryExtractConformanceFromInterfaceType(
        builder, primalInterfaceType, (IRWitnessTable*)primalWitnessTable))
    {
        if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
        {
            auto differentialAssociatedType = differentiateType(builder, primalInterfaceType);
            SLANG_ASSERT(differentialAssociatedType);

            diffResult = builder->emitMakeExistential(
                differentialAssociatedType,
                diffBase,
                differentialWitnessTable);
        }
    }

    return InstPair(primalResult, diffResult);
}

InstPair ForwardDiffTranscriber::transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst)
{
    IRInst* primalConstruct = maybeCloneForPrimalInst(builder, origInst);

    IRInst* diffConstruct = nullptr;

    if (auto diffType = differentiateType(builder, origInst->getDataType()))
    {
        diffConstruct = builder->emitDefaultConstructRaw(diffType);
    }
    return InstPair(primalConstruct, diffConstruct);
}

InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst)
{
    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType());

    List<IRInst*> primalArgs;
    for (UInt i = 0; i < origInst->getOperandCount(); i++)
    {
        auto primalArg = findOrTranscribePrimalInst(builder, origInst->getOperand(i));
        primalArgs.add(primalArg);
    }

    IRInst* primalResult = builder->emitIntrinsicInst(
        primalType,
        origInst->getOp(),
        primalArgs.getCount(),
        primalArgs.getBuffer());

    IRInst* diffResult = nullptr;

    if (auto diffType = differentiateType(builder, origInst->getDataType()))
    {
        List<IRInst*> diffArgs;
        for (UInt i = 0; i < origInst->getOperandCount(); i++)
        {
            auto arg = findOrTranscribeDiffInst(builder, origInst->getOperand(i));
            if (arg)
            {
                diffArgs.add(arg);
            }
            else if (i == 0)
            {
                // If we can't diff the first operand (base), abort now.
                break;
            }
        }
        if (diffArgs.getCount())
        {
            diffResult = builder->emitIntrinsicInst(
                diffType,
                origInst->getOp(),
                diffArgs.getCount(),
                diffArgs.getBuffer());
        }
    }
    return InstPair(primalResult, diffResult);
}

// Create an empty func to represent the transcribed func of `origFunc`.
InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
    if (auto fwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>())
    {
        // If we reach here, the function must have been used directly in a `call` inst, and therefore
        // can't be a generic.
        // Generic function are always referenced with `specialize` inst and the handling logic for
        // custom derivatives is implemented in `transcribeSpecialize`.
        SLANG_RELEASE_ASSERT(fwdDecor->getForwardDerivativeFunc()->getOp() == kIROp_Func);
        return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc());
    }
    
    auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);

    if (auto outerGen = findOuterGeneric(diffFunc))
    {
        IRBuilder subBuilder = *inBuilder;
        subBuilder.setInsertBefore(origFunc);
        auto specialized =
            specializeWithGeneric(subBuilder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
        subBuilder.addForwardDerivativeDecoration(origFunc, specialized);
    }
    else
    {
        inBuilder->addForwardDerivativeDecoration(origFunc, diffFunc);
    }

    inBuilder->addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);

    copyOriginalDecorations(origFunc, diffFunc);

    FuncBodyTranscriptionTask task;
    task.type = FuncBodyTranscriptionTaskType::Forward;
    task.originalFunc = origFunc;
    task.resultFunc = diffFunc;
    autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);

    return InstPair(origFunc, diffFunc);
}

IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc)
{
    IRBuilder builder = *inBuilder;

    maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);

    differentiableTypeConformanceContext.setFunc(origFunc);

    auto diffFunc = builder.createFunc();

    SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
    IRType* diffFuncType = this->differentiateFunctionType(
        &builder,
        origFunc,
        as<IRFuncType>(origFunc->getFullType()));
    diffFunc->setFullType(diffFuncType);

    if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
    {
        auto originalName = nameHint->getName();
        StringBuilder newNameSb;
        newNameSb << "s_fwd_" << originalName;
        builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
    }

    // Mark the generated derivative function itself as differentiable.
    builder.addForwardDifferentiableDecoration(diffFunc);
    if (isBackwardDifferentiableFunc(origFunc))
        builder.addBackwardDifferentiableDecoration(diffFunc);
    
    // Transfer checkpoint hint decorations
    copyCheckpointHints(&builder, origFunc, diffFunc);

    // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
    if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
    {
        cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule());
    }
    return diffFunc;
}

void ForwardDiffTranscriber::checkAutodiffInstDecorations(IRFunc* fwdFunc)
{
    for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock())
    {
        for (auto inst = block->getFirstOrdinaryInst(); inst; inst = inst->getNextInst())
        {
            // TODO: Special case, not sure why these insts show up
            if (as<IRUndefined>(inst)) continue;

            List<IRDecoration*> decorations;
            for (auto decoration : inst->getDecorations())
            {
                if (as<IRAutodiffInstDecoration>(decoration))
                    decorations.add(decoration);
            }

            // Must have _exactly_ one autodiff tag.
            SLANG_ASSERT(decorations.getCount() == 1);
        }
    }
}

void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
{
    IRBuilder builder(module);
    auto firstBlock = func->getFirstBlock();
    builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());

    OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar;
    List<IRParam*> params;
    for (auto param : firstBlock->getParams())
    {
        if (const auto ptrType = as<IROutTypeBase>(param->getDataType()))
        {
            params.add(param);
        }
    }

    for (auto param : params)
    {
        auto ptrType = as<IRPtrTypeBase>(param->getDataType());
        auto tempVar = builder.emitVar(ptrType->getValueType());
        param->replaceUsesWith(tempVar);
        mapParamToTempVar[param] = tempVar;
        if (ptrType->getOp() != kIROp_OutType)
        {
            builder.emitStore(tempVar, builder.emitLoad(param));
        }
        else
        {
            builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType()));
        }
    }

    for (auto block : func->getBlocks())
    {
        for (auto inst : block->getChildren())
        {
            if (inst->getOp() == kIROp_Return)
            {
                builder.setInsertBefore(inst);
                for (const auto& [param, var] : mapParamToTempVar)
                    builder.emitStore(param, builder.emitLoad(var));
            }
        }
    }
}

bool isLocalPointer(IRInst* ptrInst)
{
    // If it's not a local var or a function parameter, then it's probably 
    // referencing something outside the function scope.
    // 
    auto addr = getRootAddr(ptrInst);
    return as<IRVar>(addr) || as<IRParam, IRDynamicCastBehavior::NoUnwrap>(addr);
}

void lowerSwizzledStores(IRModule* module, IRFunc* func)
{
    List<IRInst*> instsToRemove;

    IRBuilder builder(module);
    for (auto block : func->getBlocks())
    {
        for (auto inst : block->getChildren())
        {
            if (auto swizzledStore = as<IRSwizzledStore>(inst))
            {
                if (!isLocalPointer(swizzledStore->getDest()))
                    continue;

                builder.setInsertBefore(inst);
                for (UIndex ii = 0; ii < swizzledStore->getElementCount(); ii++)
                {
                    auto indexVal = swizzledStore->getElementIndex(ii);
                    auto indexedPtr = builder.emitElementAddress(swizzledStore->getDest(), indexVal);
                    builder.emitStore(
                        indexedPtr, 
                        builder.emitElementExtract(
                            swizzledStore->getSource(), 
                            builder.getIntValue(builder.getIntType(), ii)));
                }
                instsToRemove.add(inst);
            }
        }
    }

    for (auto inst : instsToRemove)
    {
        inst->removeAndDeallocate();
    }
}

SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
{
    insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
    removeLinkageDecorations(func);
    
    performPreAutoDiffForceInlining(func);

    initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);

    lowerSwizzledStores(autoDiffSharedContext->moduleInst->getModule(), func);

    auto result = eliminateAddressInsts(func, sink);

    if (SLANG_SUCCEEDED(result))
    {
        disableIRValidationAtInsert();
        simplifyFunc(autoDiffSharedContext->targetRequest, func, IRSimplificationOptions::getDefault());
        enableIRValidationAtInsert();
    }
    return result;
}

// Transcribe a function definition.
InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
{
    if (primalFunc->findDecoration<IRTreatAsDifferentiableDecoration>())
    {
        // Generate a trivial implementation for [TreatAsDifferentiable] functions.
        generateTrivialFwdDiffFunc(primalFunc, diffFunc);
        return InstPair(primalFunc, diffFunc);
    }

    IRBuilder builder = *inBuilder;
    builder.setInsertBefore(primalFunc);

    // Create a clone for original func and run additional transformations on the clone.
    IRCloneEnv env;
    auto primalFuncClone = as<IRFunc>(cloneInst(&env, &builder, primalFunc));
    prepareFuncForForwardDiff(primalFuncClone);

    builder.setInsertInto(diffFunc);

    differentiableTypeConformanceContext.setFunc(primalFuncClone);
    
    mapInOutParamToWriteBackValue.clear();

    // Create and map blocks in diff func.
    for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock())
    {
        auto diffBlock = builder.emitBlock();
        mapPrimalInst(block, diffBlock);
        mapDifferentialInst(block, diffBlock);
    }

    // Now actually transcribe the content of each block.
    for (auto block = primalFuncClone->getFirstBlock(); block; block = block->getNextBlock())
        this->transcribeBlock(&builder, block);

    for (auto block : diffFunc->getBlocks())
    {
        for (auto inst : block->getChildren())
        {
            if (inst->getOp() == kIROp_Return)
            {
                // Insert write backs to mutable parameters before returning.
                builder.setInsertBefore(inst);
                for (auto& writeBack : mapInOutParamToWriteBackValue)
                {
                    auto param = writeBack.key;
                    auto primalVal = builder.emitLoad(writeBack.value.primal);
                    IRInst* valToStore = nullptr;
                    if (writeBack.value.differential)
                    {
                        auto diffVal = builder.emitLoad(writeBack.value.differential);
                        builder.markInstAsDifferential(diffVal, primalVal->getFullType());

                        valToStore = builder.emitMakeDifferentialPair(cast<IRPtrTypeBase>(param->getFullType())->getValueType(),
                            primalVal, diffVal);
                        builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType());
                    }
                    else
                    {
                        valToStore = builder.emitLoad(writeBack.value.primal);
                    }

                    auto storeInst = builder.emitStore(param, valToStore);

                    if (writeBack.value.differential)
                    {
                        builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType());
                    }
                }
            }
        }
    }
    
#if _DEBUG
    checkAutodiffInstDecorations(diffFunc);
#endif

    return InstPair(primalFunc, diffFunc);
}

InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
    // Handle common SSA-style operations
    switch (origInst->getOp())
    {
    case kIROp_Param:
        return transcribeParam(builder, as<IRParam>(origInst));

    case kIROp_Var:
        return transcribeVar(builder, as<IRVar>(origInst));

    case kIROp_Load:
        return transcribeLoad(builder, as<IRLoad>(origInst));

    case kIROp_Store:
        return transcribeStore(builder, as<IRStore>(origInst));

    case kIROp_Return:
        return transcribeReturn(builder, as<IRReturn>(origInst));

    case kIROp_Add:
    case kIROp_Mul:
    case kIROp_Sub:
    case kIROp_Div:
        return transcribeBinaryArith(builder, origInst);
    
    case kIROp_Less:
    case kIROp_Greater:
    case kIROp_And:
    case kIROp_Or:
    case kIROp_Geq:
    case kIROp_Leq:
    case kIROp_Eql:
    case kIROp_Neq:
        return transcribeBinaryLogic(builder, origInst);
    
    case kIROp_Select:
        return transcribeSelect(builder, origInst);

    case kIROp_MakeVector:
    case kIROp_MakeMatrix:
    case kIROp_MakeMatrixFromScalar:
    case kIROp_MatrixReshape:
    case kIROp_IntCast:
    case kIROp_FloatCast:
    case kIROp_MakeVectorFromScalar:
    case kIROp_MakeArray:
    case kIROp_MakeArrayFromElement:
        return transcribeConstruct(builder, origInst);
    case kIROp_MakeStruct:
        return transcribeMakeStruct(builder, origInst);

    case kIROp_LookupWitness:
        return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));

    case kIROp_Call:
        return transcribeCall(builder, as<IRCall>(origInst));
    
    case kIROp_swizzle:
        return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
    
    case kIROp_MakeTuple:
    case kIROp_Neg:
        return transcribeByPassthrough(builder, origInst);

    case kIROp_UpdateElement:
        return transcribeUpdateElement(builder, origInst);

    case kIROp_unconditionalBranch:
    case kIROp_loop:
        return transcribeControlFlow(builder, origInst);

    case kIROp_FloatLit:
    case kIROp_IntLit:
    case kIROp_VoidLit:
        return transcribeConst(builder, origInst);

    case kIROp_Specialize:
        return transcribeSpecialize(builder, as<IRSpecialize>(origInst));

    case kIROp_FieldExtract:
    case kIROp_FieldAddress:
        return transcribeFieldExtract(builder, origInst);

    case kIROp_GetElement:
    case kIROp_GetElementPtr:
        return transcribeGetElement(builder, origInst);

    case kIROp_ifElse:
        return transcribeIfElse(builder, as<IRIfElse>(origInst));
    
    case kIROp_Switch:
        return transcribeSwitch(builder, as<IRSwitch>(origInst));

    case kIROp_MakeDifferentialPairUserCode:
        return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPairUserCode>(origInst));

    case kIROp_DifferentialPairGetPrimalUserCode:
    case kIROp_DifferentialPairGetDifferentialUserCode:
        return transcribeDifferentialPairGetElement(builder, origInst);

    case kIROp_ExtractExistentialValue:
        return transcribeSingleOperandInst(builder, origInst);

    case kIROp_PackAnyValue:
        return transcribeSingleOperandInst(builder, origInst);
        
    case kIROp_MakeExistential:
        return transcribeMakeExistential(builder, as<IRMakeExistential>(origInst));

    case kIROp_ExtractExistentialType:
    {
        IRInst* witnessTable;
        auto diffType = differentiateExtractExistentialType(
                builder, as<IRExtractExistentialType>(origInst), witnessTable);

        // Mark types as primal since they are not transposable.
        if (diffType)
            builder->markInstAsPrimal(diffType);

        return InstPair(
            maybeCloneForPrimalInst(builder, origInst),
            diffType);
    }
    case kIROp_ExtractExistentialWitnessTable:
        return transcribeExtractExistentialWitnessTable(builder, origInst);

    case kIROp_WrapExistential:
        return transcribeWrapExistential(builder, origInst);

    case kIROp_DefaultConstruct:
        return transcribeDefaultConstruct(builder, origInst);

    case kIROp_undefined:
        return transcribeUndefined(builder, origInst);
    
    case kIROp_Reinterpret:
        return transcribeReinterpret(builder, origInst);

        // Differentiable insts that should have been lowered in a previous pass.
    case kIROp_SwizzledStore: 
    {
        // If we have a non-null dest ptr, then we error out because something went wrong
        // when lowering swizzle-stores to regular stores
        //
        auto swizzledStore = as<IRSwizzledStore>(origInst);
        SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr);
        return transcribeNonDiffInst(builder, swizzledStore);
    }
        // Known non-differentiable insts.
    case kIROp_Not:
    case kIROp_BitAnd:
    case kIROp_BitNot:
    case kIROp_BitXor:
    case kIROp_BitOr:
    case kIROp_BitCast:
    case kIROp_Lsh:
    case kIROp_Rsh:
    case kIROp_IRem:
    case kIROp_ByteAddressBufferLoad:
    case kIROp_ByteAddressBufferStore:
    case kIROp_StructuredBufferLoad:
    case kIROp_RWStructuredBufferLoad:
    case kIROp_RWStructuredBufferLoadStatus:
    case kIROp_RWStructuredBufferStore:
    case kIROp_RWStructuredBufferGetElementPtr:
    case kIROp_IsType:
    case kIROp_ImageSubscript:
    case kIROp_ImageLoad:
    case kIROp_ImageStore:
    case kIROp_UnpackAnyValue:
    case kIROp_GetNativePtr:
    case kIROp_CastIntToFloat:
    case kIROp_CastFloatToInt:
    case kIROp_DetachDerivative:
    case kIROp_GetSequentialID:
    case kIROp_GetStringHash:
        return transcribeNonDiffInst(builder, origInst);

        // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
        // so we treat this inst as non differentiable.
        // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
        // 
        // However, we can't skip this instruction since it also produces a _type_ which may be used by
        // other differentiable instructions. Therefore, we'll create another existential object but with
        // a dzero() for it's value.
        // 
    case kIROp_CreateExistentialObject:
        return transcribeNonDiffInst(builder, origInst);

    case kIROp_StructKey:
        return InstPair(origInst, nullptr);
        
    case kIROp_Unreachable:
    {
        auto unreachInst = builder->emitUnreachable();
        builder->markInstAsMixedDifferential(unreachInst);
        return InstPair(unreachInst, nullptr);
    }

    case kIROp_MakeExistentialWithRTTI:
        SLANG_UNEXPECTED("MakeExistentialWithRTTI inst is not expected in autodiff pass.");
        break;
    }

    return InstPair(nullptr, nullptr);
}

String ForwardDiffTranscriber::makeDiffPairName(IRInst* origVar)
{
    if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
    {
        return ("dp" + String(namehintDecoration->getName()));
    }

    return String("");
}

InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
{
    SLANG_UNUSED(primalType);

    if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origParam->getFullType()))
    {
        IRInst* diffPairParam = builder->emitParam(diffPairType);

        auto diffPairVarName = makeDiffPairName(origParam);
        if (diffPairVarName.getLength() > 0)
            builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());

        SLANG_ASSERT(diffPairParam);

        if (auto pairType = as<IRDifferentialPairType>(diffPairType))
        {
            return InstPair(
                builder->emitDifferentialPairGetPrimal(diffPairParam),
                builder->emitDifferentialPairGetDifferential(
                    (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType),
                    diffPairParam));
        }
        else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
        {
            auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
            // Make a local copy of the parameter for primal and diff parts.
            auto primal = builder->emitVar(ptrInnerPairType->getValueType());

            auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType());
            auto diff = builder->emitVar(diffType);
            builder->markInstAsDifferential(
                diff, builder->getPtrType(ptrInnerPairType->getValueType()));

            IRInst* primalInitVal = nullptr;
            IRInst* diffInitVal = nullptr;
            if (as<IROutType>(diffPairType))
            {
                primalInitVal = builder->emitDefaultConstruct(ptrInnerPairType->getValueType());
                diffInitVal = builder->emitDefaultConstructRaw(diffType);
            }
            else
            {
                auto initVal = builder->emitLoad(diffPairParam);
                builder->markInstAsMixedDifferential(initVal, ptrInnerPairType);

                primalInitVal = builder->emitDifferentialPairGetPrimal(initVal);
                diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal);
            }
            builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType());
            
            builder->emitStore(primal, primalInitVal);

            auto diffStore = builder->emitStore(diff, diffInitVal);
            builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType());

            mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff);
            return InstPair(primal, diff);
        }
    }

    auto primalInst = cloneInst(&cloneEnv, builder, origParam);
    if (auto primalParam = as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalInst))
    {
        SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
        primalParam->removeFromParent();
        builder->getInsertLoc().getBlock()->addParam(primalParam);
    }
    return InstPair(primalInst, nullptr);
}

}
back to top