https://github.com/shader-slang/slang
Raw File
Tip revision: 01efe34dbef2be952298075abd8d36cc67ac9f4e authored by Yong He on 04 March 2024, 21:14:21 UTC
Add `IGlobalSession::getSessionDescDigest`. (#3669)
Tip revision: 01efe34
slang-ir-autodiff-transpose.h
// slang-ir-autodiff-transpose.h
#pragma once

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-compiler.h"

#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-cfg-norm.h"
#include "slang-ir-autodiff-primal-hoist.h"
#include "slang-ir-dominators.h"

namespace Slang
{

struct DiffTransposePass
{
    
    // One reverse-mode gradient contribution: `revGradInst` is the gradient to be
    // accumulated for `targetInst`. `flavor` records how the gradient relates to
    // the target (the whole value, or a swizzle / element / field / differential
    // part of it). `fwdGradInst` is an optional associated inst; it is nullptr at
    // the call sites visible in this file (presumably the forward-mode counterpart
    // when present — TODO confirm).
    //
    struct RevGradient
    {
        enum Flavor 
        {
            Simple,
            Swizzle,
            GetElement,
            GetDifferential,
            FieldExtract,
            DifferentialPairGetElementUserCode,
            Invalid
        };

        // Default: an empty entry with an `Invalid` flavor.
        // Note: mem-initializers are listed in declaration order, which is the
        // order C++ actually initializes members in (avoids -Wreorder).
        RevGradient() :
            targetInst(nullptr), revGradInst(nullptr), fwdGradInst(nullptr), flavor(Flavor::Invalid)
        { }
        
        // Fully specified entry with an explicit flavor.
        RevGradient(Flavor flavor, IRInst* targetInst, IRInst* revGradInst, IRInst* fwdGradInst) : 
            targetInst(targetInst), revGradInst(revGradInst), fwdGradInst(fwdGradInst), flavor(flavor)
        { }

        // Convenience constructor for a `Simple` (whole-value) gradient.
        RevGradient(IRInst* targetInst, IRInst* revGradInst, IRInst* fwdGradInst) : 
            targetInst(targetInst), revGradInst(revGradInst), fwdGradInst(fwdGradInst), flavor(Flavor::Simple)
        { }

        // Member-wise equality over all four fields.
        bool operator==(const RevGradient& other) const
        {
            return (other.targetInst == targetInst) && 
                (other.revGradInst == revGradInst) && 
                (other.fwdGradInst == fwdGradInst) &&
                (other.flavor == flavor);
        }
        
        IRInst* targetInst;
        IRInst* revGradInst;
        IRInst* fwdGradInst;

        Flavor flavor;
    };

    // Binds the pass to the shared autodiff context. The pair builder and the
    // differentiable-type context are both constructed from that same context.
    DiffTransposePass(AutoDiffSharedContext* autodiffContext)
        : autodiffContext(autodiffContext)
        , pairBuilder(autodiffContext)
        , diffTypeContext(autodiffContext)
    {
    }

    struct TranspositionResult
    {
        // Holds a set of pairs of 
        // (original-inst, inst-to-accumulate-for-orig-inst)
        List<RevGradient> revPairs;

        TranspositionResult()
        { }

        TranspositionResult(List<RevGradient> revPairs) : revPairs(revPairs)
        { }
    };

    struct FuncTranspositionInfo
    {
        // Inst that represents the reverse-mode derivative
        // of the *output* of the function.
        //
        // Defaults to nullptr so a default-constructed info never carries an
        // indeterminate pointer; `transposeDiffBlocksInFunc` tests this field
        // for null before using it.
        // 
        IRInst* dOutInst = nullptr;
    };

    // Pairs a fwd-mode block with a list of phi-gradient insts.
    struct PendingBlockTerminatorEntry
    {
        IRBlock* fwdBlock;
        List<IRInst*> phiGrads;

        PendingBlockTerminatorEntry() : fwdBlock(nullptr)
        {}

        // `phiGrads` is taken by const reference to avoid copying the list twice.
        PendingBlockTerminatorEntry(IRBlock* fwdBlock, const List<IRInst*>& phiGrads) : 
            fwdBlock(fwdBlock), phiGrads(phiGrads)
        {}
    };

    bool isBlockLastInRegion(IRBlock* block, List<IRBlock*> endBlocks)
    {
        if (auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()))
        {
            if (endBlocks.contains(branchInst->getTargetBlock()))
                return true;
            else
                return false;
        }
        else if (as<IRReturn>(block->getTerminator()))
        {
            return true;
        }

        return false;
    }

    // Returns (a copy of) the phi-gradient insts recorded for `block` in
    // `phiGradsMap`, or an empty list if none were recorded.
    // Uses a single hash lookup instead of containsKey() followed by operator[].
    //
    List<IRInst*> getPhiGrads(IRBlock* block)
    {
        if (auto phiGrads = phiGradsMap.tryGetValue(block))
            return *phiGrads;

        return List<IRInst*>();
    }

    // Summary returned by `reverseCFGRegion` for one reversed region.
    //   revEntry    - the reverse-mode block through which the reversed region is entered.
    //   fwdEndPoint - the fwd-mode block the region branched to when it ended
    //                 (nullptr when the region ended on a return).
    //   isTrivial   - set when the region started on one of its end blocks (nothing
    //                 to reverse), and also when the region walk ended on a return.
    struct RegionEntryPoint
    {
        IRBlock* revEntry;
        IRBlock* fwdEndPoint;
        bool isTrivial;

        RegionEntryPoint(IRBlock* revEntry, IRBlock* fwdEndPoint, bool isTrivial)
            : revEntry(revEntry)
            , fwdEndPoint(fwdEndPoint)
            , isTrivial(isTrivial)
        { }

        // Non-trivial region unless stated otherwise.
        RegionEntryPoint(IRBlock* revEntry, IRBlock* fwdEndPoint)
            : RegionEntryPoint(revEntry, fwdEndPoint, false)
        { }
    };

    // Returns the single distinct predecessor of `block`.
    // Predecessors are de-duplicated through a HashSet before counting, so
    // several edges from the same block count once. Asserts (when assertions are
    // enabled) that exactly one distinct predecessor exists. If there are none,
    // returns nullptr rather than dereferencing the end iterator of an empty set
    // (which would be undefined behavior once the assert is compiled out).
    //
    IRBlock* getUniquePredecessor(IRBlock* block)
    {
        HashSet<IRBlock*> predecessorSet;
        for (auto predecessor : block->getPredecessors())
            predecessorSet.add(predecessor);
        
        SLANG_ASSERT(predecessorSet.getCount() == 1);

        if (predecessorSet.getCount() == 0)
            return nullptr;

        return (*predecessorSet.begin());
    }

    // Wires up reverse-mode control flow for the fwd-mode region that starts at
    // `block` and ends when control reaches one of `endBlocks` (or a return).
    //
    // The walk follows fwd-mode terminators forward. For each construct
    // (unconditional branch, if-else, loop, switch) it emits mirrored
    // terminators into the corresponding reverse-mode blocks from `revBlockMap`,
    // recursing into nested sub-regions. Phi gradients recorded in
    // `phiGradsMap` are passed as branch arguments.
    //
    // Returns the reverse-mode block through which the reversed region is
    // entered, the fwd-mode block the region branched to on exit (nullptr on a
    // return), and whether the region was trivial.
    //
    // `endBlocks` is taken by const reference (it is only read); callers that
    // pass temporaries still bind to it.
    //
    RegionEntryPoint reverseCFGRegion(IRBlock* block, const List<IRBlock*>& endBlocks)
    {
        IRBlock* revBlock = revBlockMap[block];

        // Region starts on its own end: nothing to reverse.
        if (endBlocks.contains(block))
        {
            return RegionEntryPoint(revBlock, block, true);
        }

        // We shouldn't already have a terminator for this block
        SLANG_ASSERT(revBlock->getTerminator() == nullptr);

        IRBuilder builder(autodiffContext->moduleInst->getModule());

        auto currentBlock = block;
        while (!isBlockLastInRegion(currentBlock, endBlocks))
        {
            auto terminator = currentBlock->getTerminator();
            switch(terminator->getOp())
            {
                case kIROp_Return:
                    return RegionEntryPoint(revBlockMap[currentBlock], nullptr);

                case kIROp_unconditionalBranch:
                {
                    // fwd: current -> next   becomes   rev: rev(next) -> rev(current)
                    auto branchInst = as<IRUnconditionalBranch>(terminator);
                    auto nextBlock = as<IRBlock>(branchInst->getTargetBlock());
                    IRBlock* nextRevBlock = revBlockMap[nextBlock];
                    IRBlock* currRevBlock = revBlockMap[currentBlock];

                    SLANG_ASSERT(nextRevBlock->getTerminator() == nullptr);
                    builder.setInsertInto(nextRevBlock);

                    builder.emitBranch(currRevBlock,
                        getPhiGrads(nextBlock).getCount(),
                        getPhiGrads(nextBlock).getBuffer());
                    

                    currentBlock = nextBlock;
                    break;
                }

                case kIROp_ifElse:
                {
                    auto ifElse = as<IRIfElse>(terminator);
                    
                    auto trueBlock = ifElse->getTrueBlock();
                    auto falseBlock = ifElse->getFalseBlock();
                    auto afterBlock = ifElse->getAfterBlock();

                    // Reverse both arms; each arm region ends at the merge block.
                    auto revTrueRegionInfo = reverseCFGRegion(
                        trueBlock,
                        List<IRBlock*>(afterBlock));
                    auto revFalseRegionInfo = reverseCFGRegion(
                        falseBlock,
                        List<IRBlock*>(afterBlock));
                    //bool isTrueTrivial = (trueBlock == afterBlock);
                    //bool isFalseTrivial = (falseBlock == afterBlock);

                    // The reversed condition lives in rev(afterBlock).
                    IRBlock* revCondBlock = revBlockMap[afterBlock];
                    SLANG_ASSERT(revCondBlock->getTerminator() == nullptr);


                    IRBlock* revTrueEntryBlock = revTrueRegionInfo.revEntry;
                    IRBlock* revFalseEntryBlock = revFalseRegionInfo.revEntry;

                    IRBlock* revTrueExitBlock = revBlockMap[trueBlock];
                    IRBlock* revFalseExitBlock = revBlockMap[falseBlock];

                    auto phiGrads = getPhiGrads(afterBlock);
                    if (phiGrads.getCount() > 0)
                    {
                        revTrueEntryBlock = insertPhiBlockBefore(revTrueEntryBlock, phiGrads);
                        revFalseEntryBlock = insertPhiBlockBefore(revFalseEntryBlock, phiGrads);
                    }

                    // The reversed merge point is rev(currentBlock).
                    IRBlock* revAfterBlock = revBlockMap[currentBlock];
                    
                    builder.setInsertInto(revCondBlock);

                    builder.emitIfElse(
                        ifElse->getCondition(),
                        revTrueEntryBlock,
                        revFalseEntryBlock,
                        revAfterBlock);
                    
                    if (!revTrueRegionInfo.isTrivial)
                    {
                        builder.setInsertInto(revTrueExitBlock);
                        SLANG_ASSERT(revTrueExitBlock->getTerminator() == nullptr);
                        builder.emitBranch(
                            revAfterBlock,
                            getPhiGrads(trueBlock).getCount(),
                            getPhiGrads(trueBlock).getBuffer());
                    }

                    if (!revFalseRegionInfo.isTrivial)
                    {
                        builder.setInsertInto(revFalseExitBlock);
                        SLANG_ASSERT(revFalseExitBlock->getTerminator() == nullptr);
                        builder.emitBranch(
                            revAfterBlock,
                            getPhiGrads(falseBlock).getCount(),
                            getPhiGrads(falseBlock).getBuffer());
                    }

                    currentBlock = afterBlock;
                    break;
                }

                case kIROp_loop:
                {
                    auto loop = as<IRLoop>(terminator);
                    
                    auto firstLoopBlock = loop->getTargetBlock();
                    auto breakBlock = loop->getBreakBlock();

                    auto condBlock = getOrCreateTopLevelCondition(loop);

                    auto ifElse = as<IRIfElse>(condBlock->getTerminator());

                    auto trueBlock = ifElse->getTrueBlock();
                    auto falseBlock = ifElse->getFalseBlock();

                    auto trueRegionInfo = reverseCFGRegion(
                        trueBlock,
                        List<IRBlock*>(breakBlock, condBlock));

                    auto falseRegionInfo = reverseCFGRegion(
                        falseBlock,
                        List<IRBlock*>(breakBlock, condBlock));

                    auto preCondRegionInfo = reverseCFGRegion(
                        firstLoopBlock,
                        List<IRBlock*>(condBlock));

                    // assume loop[next] -> cond can be a region and reverse it.
                    // assume cond[false] -> break can be a region and reverse it.
                    // assume cond[true] -> cond can be a region and reverse it.
                    // rev-loop = rev[break]
                    // rev-cond = rev[cond]
                    // rev-cond[true] -> entry of (cond[true] -> cond)
                    // rev-cond[false] -> entry of (loop[next] -> cond)
                    // exit of (cond[false]->break) branches into rev-cond
                    // rev-loop[next] -> entry of (cond[false] -> break)
                    // exit of (cond[true] -> cond) branches into rev-cond
                    // exit of (loop[next] -> cond) branches into rev[loop] (rev-break)

                    // For now, we'll assume the loop is always on the 'true' side
                    // If this assert fails, add in the case where the loop
                    // may be on the 'false' side.
                    // 
                    SLANG_RELEASE_ASSERT(trueRegionInfo.fwdEndPoint == condBlock);

                    auto revTrueBlock = trueRegionInfo.revEntry;
                    auto revFalseBlock = (preCondRegionInfo.isTrivial) ? 
                        revBlockMap[currentBlock] : preCondRegionInfo.revEntry;
                    
                    // The block that will become target of the new loop inst
                    // (the old false-region) This _could_ be the condition itself
                    // 
                    IRBlock* revPreCondBlock = (falseRegionInfo.isTrivial) ? 
                        revBlockMap[condBlock] : falseRegionInfo.revEntry;
                    
                    // Old cond block remains new cond block.
                    IRBlock* revCondBlock = revBlockMap[condBlock];

                    // Old cond block becomes new pre-break block.
                    IRBlock* revBreakBlock = revBlockMap[currentBlock];

                    // Old true-side starting block becomes loop end block.
                    IRBlock* revLoopEndBlock = revBlockMap[trueBlock];
                    builder.setInsertInto(revLoopEndBlock);
                    builder.emitBranch(
                        revCondBlock,
                        getPhiGrads(trueBlock).getCount(),
                        getPhiGrads(trueBlock).getBuffer());
                    

                    IRBlock* revBreakRegionExitBlock = revBlockMap[firstLoopBlock];
                    if (!preCondRegionInfo.isTrivial)
                    {
                        builder.setInsertInto(revBreakRegionExitBlock);
                        builder.emitBranch(
                            revBreakBlock,
                            getPhiGrads(firstLoopBlock).getCount(),
                            getPhiGrads(firstLoopBlock).getBuffer());
                    }

                    auto phiGrads = getPhiGrads(condBlock);
                    if (phiGrads.getCount() > 0)
                    {
                        revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiGrads);
                        revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiGrads);
                    }

                    // Emit condition into the new cond block.
                    builder.setInsertInto(revCondBlock);

                    builder.emitIfElse(
                        ifElse->getCondition(),
                        revTrueBlock,
                        revFalseBlock,
                        revTrueBlock);

                    // The reversed loop is marked as differential against the
                    // primal loop, found through the parent block's diff decoration.
                    auto loopParentBlockDiffDecor = loop->getParent()->findDecoration<IRDifferentialInstDecoration>();
                    SLANG_RELEASE_ASSERT(loopParentBlockDiffDecor);
                    auto primalBlock = as<IRBlock>(loopParentBlockDiffDecor->getPrimalInst());
                    auto primalLoop = as<IRLoop>(primalBlock->getTerminator());
                    SLANG_RELEASE_ASSERT(primalLoop);

                    // Old false-side starting block becomes end block 
                    // for the new pre-cond region (which could be empty)
                    // 
                    if (!falseRegionInfo.isTrivial)
                    {
                        IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
                        builder.setInsertInto(revPreCondEndBlock);
                        auto revLoop = builder.emitLoop(
                            revCondBlock,
                            revBreakBlock,
                            revLoopEndBlock,
                            getPhiGrads(falseBlock).getCount(),
                            getPhiGrads(falseBlock).getBuffer());
                        loop->transferDecorationsTo(revLoop);
                        builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop);

                        auto revLoopStartBlock = revBlockMap[breakBlock];
                        builder.setInsertInto(revLoopStartBlock);
                        builder.emitBranch(
                            revPreCondBlock,
                            getPhiGrads(breakBlock).getCount(),
                            getPhiGrads(breakBlock).getBuffer());
                    }
                    else
                    {
                        // Emit loop into rev-version of the break block.
                        auto revLoopBlock = revBlockMap[breakBlock];
                        builder.setInsertInto(revLoopBlock);
                        auto revLoop = builder.emitLoop(
                            revPreCondBlock,
                            revBreakBlock,
                            revLoopEndBlock,
                            getPhiGrads(breakBlock).getCount(),
                            getPhiGrads(breakBlock).getBuffer());
                        loop->transferDecorationsTo(revLoop);
                        builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop);
                    }

                    currentBlock = breakBlock;
                    break;
                }

                case kIROp_Switch:
                {
                    auto switchInst = as<IRSwitch>(terminator);

                    auto breakBlock = switchInst->getBreakLabel();

                    IRBlock* revBreakBlock = revBlockMap[currentBlock];

                    // Reverse each case label
                    // `reverseSwitchArgs` is laid out as (caseValue, revCaseEntry) pairs.
                    List<IRInst*> reverseSwitchArgs;
                    Dictionary<IRBlock*, IRBlock*> reverseLabelEntryBlocks;

                    for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++)
                    {
                        reverseSwitchArgs.add(switchInst->getCaseValue(ii));

                        auto caseLabel = switchInst->getCaseLabel(ii);
                        if (!reverseLabelEntryBlocks.containsKey(caseLabel))
                        {
                            auto labelRegionInfo = reverseCFGRegion(
                                caseLabel,
                                List<IRBlock*>(breakBlock));

                            // Handle this case eventually.
                            SLANG_ASSERT(!labelRegionInfo.isTrivial);

                            // Wire the exit to the break block
                            IRBlock* revLabelExit = revBlockMap[caseLabel];
                            SLANG_ASSERT(revLabelExit->getTerminator() == nullptr);

                            builder.setInsertInto(revLabelExit);
                            builder.emitBranch(revBreakBlock);
                            
                            reverseLabelEntryBlocks[caseLabel] = labelRegionInfo.revEntry;
                            reverseSwitchArgs.add(labelRegionInfo.revEntry);
                        }
                        else
                        {
                            reverseSwitchArgs.add(reverseLabelEntryBlocks[caseLabel]);
                        }
                    }
                    
                    auto defaultRegionInfo = reverseCFGRegion(
                        switchInst->getDefaultLabel(),
                        List<IRBlock*>(breakBlock));
                    SLANG_ASSERT(!defaultRegionInfo.isTrivial);
                    
                    auto revDefaultRegionEntry = defaultRegionInfo.revEntry;

                    builder.setInsertInto(revBlockMap[switchInst->getDefaultLabel()]);
                    builder.emitBranch(revBreakBlock);

                    auto phiGrads = getPhiGrads(breakBlock);
                    if (phiGrads.getCount() > 0)
                    {
                        for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++)
                        {
                            reverseSwitchArgs[ii * 2 + 1] =
                                insertPhiBlockBefore(as<IRBlock>(reverseSwitchArgs[ii * 2 + 1]), phiGrads);
                        }
                        revDefaultRegionEntry =
                                insertPhiBlockBefore(as<IRBlock>(revDefaultRegionEntry), phiGrads);
                    }

                    auto revSwitchBlock = revBlockMap[breakBlock];

                    builder.setInsertInto(revSwitchBlock);

                    builder.emitSwitch(
                        switchInst->getCondition(),
                        revBreakBlock,
                        revDefaultRegionEntry,
                        reverseSwitchArgs.getCount(),
                        reverseSwitchArgs.getBuffer());
                    
                    currentBlock = breakBlock;
                    break;
                }

                default:
                    // Any other terminator leaves `currentBlock` unchanged, which
                    // would make this walk loop forever. Report it instead.
                    SLANG_UNEXPECTED("Unhandled terminator while reversing a CFG region");
                    break;
            }
        }

        if (auto branchInst = as<IRUnconditionalBranch>(currentBlock->getTerminator()))
        {
            return RegionEntryPoint(
                revBlockMap[currentBlock],
                branchInst->getTargetBlock(),
                false);
        }
        else if (as<IRReturn>(currentBlock->getTerminator()))
        {
            return RegionEntryPoint(
                revBlockMap[currentBlock],
                nullptr,
                true);
        }
        else
        {
            // Regions should _really_ not end on a conditional branch (I think)
            SLANG_UNEXPECTED("Unexpected: Region ended on a conditional branch");
        }
    }

    // Transposes every differential (fwd-mode) block of `revDiffFunc`:
    //   1. creates an empty reverse-mode block per differential block,
    //   2. moves differential vars into the first reverse-mode block and
    //      zero-initializes them,
    //   3. transposes each block's insts,
    //   4. wires the reverse-mode blocks with reversed control flow,
    //   5. links the terminal primal block to the first reverse-mode block, and
    //   6. removes the original fwd-mode differential blocks.
    //
    // When `transposeInfo.dOutInst` is non-null it is registered as the reverse
    // gradient for the return inst of the differential blocks.
    //
    void transposeDiffBlocksInFunc(
        IRFunc* revDiffFunc,
        FuncTranspositionInfo transposeInfo)
    {
        // TODO (sai): We really to make this method stateless 
        // (i.e. not store per-func info in 'this')
        // since it is reused for every reverse-mode call.
        //
        // Grab all differentiable type information.
        diffTypeContext.setFunc(revDiffFunc);
        
        // Note down terminal primal and terminal differential blocks
        // since we need to link them up at the end.
        auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc);
        auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc);

        // Traverse all instructions/blocks in reverse (starting from the terminator inst)
        // look for insts/blocks marked with IRDifferentialInstDecoration,
        // and transpose them in the revDiffFunc.
        //
        IRBuilder builder(autodiffContext->moduleInst);

        // Insert after the last block.
        builder.setInsertInto(revDiffFunc);

        // Build initial list of blocks to process by checking if they're differential blocks.
        List<IRBlock*> workList;
        for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
        {
            if (!isDifferentialInst(block))
            {
                // Skip blocks that aren't computing differentials.
                // At this stage we should have 'unzipped' the function
                // into blocks that either entirely deal with primal insts,
                // or entirely with differential insts.
                continue;
            }

            workList.add(block);
        }

        if (!workList.getCount())
            return;

        // We assume that the original function is in single-return form,
        // so there should be exactly 1 'last' block of each type.
        // Checked up-front because both lists are indexed with [0] below.
        //
        SLANG_ASSERT(terminalPrimalBlocks.getCount() == 1);
        SLANG_ASSERT(terminalDiffBlocks.getCount() == 1);

        // Reverse the order of the blocks.
        workList.reverse();
        
        // Emit empty rev-mode blocks for every fwd-mode block.
        for (auto block : workList)
        {
            auto revBlock = builder.emitBlock();
            revBlockMap[block] = revBlock;
            if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>())
                builder.markInstAsDifferential(revBlock, builder.getBasicBlockType(), diffDecor->getPrimalInst());
        }

        // Keep track of first diff block, since this is where 
        // we'll emit temporary vars to hold per-block derivatives.
        // 
        auto firstRevDiffBlock = revBlockMap.getValue(terminalDiffBlocks[0]);
        firstRevDiffBlockMap[revDiffFunc] = firstRevDiffBlock;

        // Move all diff vars to first block, and initialize them with zero.
        builder.setInsertInto(firstRevDiffBlock);
        for (auto block : workList)
        {
            for (auto inst = block->getFirstInst(); inst;)
            {
                // Grab the next inst first: `inst` may be moved out of `block`.
                auto nextInst = inst->getNextInst();
                if (auto varInst = as<IRVar>(inst))
                {
                    if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
                    {
                        if (auto ptrPrimalType = as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
                        {
                            varInst->insertAtEnd(firstRevDiffBlock);

                            auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType());
                            builder.emitStore(varInst, dzero);
                        }
                        else
                        {
                            SLANG_UNEXPECTED("Expected an pointer-typed differential variable.");
                        }
                    }
                }
                inst = nextInst;
            }
        }

        // Make a temporary block to hold inverted insts.
        tempInvBlock = builder.createBlock();

        for (auto block : workList)
        {
            // Set dOutParameter as the transpose gradient for the return inst, if any.
            if (transposeInfo.dOutInst)
            {
                if (auto returnInst = as<IRReturn>(block->getTerminator()))
                {
                    this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr));
                }
            }

            IRBlock* revBlock = revBlockMap[block];
            this->transposeBlock(block, revBlock);
        }

        // At this point all insts have been transposed, but the blocks
        // have no control flow.
        // reverseCFG will use fwd-mode blocks as reference, and 
        // wire the corresponding rev-mode blocks in reverse.
        // 
        auto branchInst = as<IRUnconditionalBranch>(terminalPrimalBlocks[0]->getTerminator());
        // The terminal primal block is expected to branch straight into the
        // first fwd-mode differential block.
        SLANG_ASSERT(branchInst);
        auto firstFwdDiffBlock = branchInst->getTargetBlock();
        reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>());

        // Link the last differential fwd-mode block (which will be the first
        // rev-mode block) as the successor to the last primal block.
        // 
        {
            auto terminalPrimalBlock = terminalPrimalBlocks[0];
            auto firstRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]);

            auto returnDecoration = 
                terminalPrimalBlock->getTerminator()->findDecoration<IRBackwardDerivativePrimalReturnDecoration>();
            SLANG_ASSERT(returnDecoration);
            auto retVal = returnDecoration->getBackwardDerivativePrimalReturnValue();

            terminalPrimalBlock->getTerminator()->removeAndDeallocate();
            
            IRBuilder subBuilder = builder;
            subBuilder.setInsertInto(terminalPrimalBlock);

            // There should be no parameters in the first reverse-mode block.
            SLANG_ASSERT(firstRevBlock->getFirstParam() == nullptr);

            auto branch = subBuilder.emitBranch(firstRevBlock);

            subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
        }

        // At this point, the only block left without terminator insts
        // should be the last one. Add a void return to complete it.
        // 
        IRBlock* lastRevBlock = revBlockMap[firstFwdDiffBlock];
        SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr);

        builder.setInsertInto(lastRevBlock);
        builder.emitReturn();

        // Remove fwd-mode blocks.
        for (auto block : workList)
        {
            block->removeAndDeallocate();
        }
    }

    // Loads the gradient accumulated so far for `fwdInst` from its accumulator
    // var (obtained via getOrCreateAccumulatorVar). It then stores the
    // differential zero of `fwdInst`'s primal type back into that var, so later
    // accumulation starts fresh. Returns the loaded value, or nullptr when no
    // accumulator var is returned.
    //
    IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst)
    {
        IRVar* accumulator = getOrCreateAccumulatorVar(fwdInst);
        if (!accumulator)
            return nullptr;

        IRInst* accumulated = builder->emitLoad(accumulator);

        auto zeroGradient = emitDZeroOfDiffInstType(
            builder,
            tryGetPrimalTypeFromDiffInst(fwdInst));
        builder->emitStore(accumulator, zeroGradient);

        return accumulated;
    }

    // Fetch or create a gradient accumulator var
    // corresponding to a inst. These are used to
    // accumulate gradients across blocks.
    //
    // New vars are emitted at the start of the first reverse-mode diff block
    // of the function containing `fwdInst` (looked up via fwdInst's
    // grandparent, i.e. block -> func). They are typed with `fwdInst`'s data
    // type and initialized to the differential zero of its primal type. The
    // var is cached in `revAccumulatorVarMap`, so repeated calls return the
    // same var.
    //
    IRVar* getOrCreateAccumulatorVar(IRInst* fwdInst)
    {
        // Check if we have a var already (single hash lookup).
        if (auto existingVar = revAccumulatorVarMap.tryGetValue(fwdInst))
            return *existingVar;
        
        IRBuilder tempVarBuilder(autodiffContext->moduleInst->getModule());
        
        IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())];
        // The first reverse diff block must have been registered for this function
        // (done in transposeDiffBlocksInFunc) before accumulators are requested.
        SLANG_ASSERT(firstDiffBlock);

        if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst())
            tempVarBuilder.setInsertBefore(firstInst);
        else
            tempVarBuilder.setInsertInto(firstDiffBlock);
        
        auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
        auto diffType = fwdInst->getDataType();

        auto zero = emitDZeroOfDiffInstType(&tempVarBuilder, primalType);

        // Emit a var in the top-level differential block to hold the gradient, 
        // and initialize it.
        auto tempRevVar = tempVarBuilder.emitVar(diffType);
        tempVarBuilder.emitStore(tempRevVar, zero);
        revAccumulatorVarMap[fwdInst] = tempRevVar;

        return tempRevVar;
    }

    // Returns true if at least one user of `inst` has a different parent than
    // `inst` itself, i.e. `inst` is referenced outside its own block.
    //
    bool isInstUsedOutsideParentBlock(IRInst* inst)
    {
        auto ownerBlock = inst->getParent();

        auto use = inst->firstUse;
        while (use)
        {
            if (use->getUser()->getParent() != ownerBlock)
                return true;
            use = use->nextUse;
        }

        return false;
    }

    void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
    {
        IRBuilder builder(autodiffContext->moduleInst);
 
        // Insert into our reverse block.
        builder.setInsertInto(revBlock);

        // Create an inverse builder to insert insts into the inv-block.
        IRBuilder invBuilder(autodiffContext->moduleInst);
        

        // Check if this block has any 'outputs' (in the form of phi args
        // sent to the successor block)
        // 
        if (auto branchInst = as<IRUnconditionalBranch>(fwdBlock->getTerminator()))
        {
            for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
            {
                auto arg = branchInst->getArg(ii);
                if (isDifferentialInst(arg))
                {
                    // If the arg is a differential, emit a parameter
                    // to accept it's reverse-mode differential as an input
                    // 

                    auto diffType = arg->getDataType();
                    auto revParam = builder.emitParam(diffType);

                    addRevGradientForFwdInst(
                        arg,
                        RevGradient(
                            RevGradient::Flavor::Simple,
                            arg,
                            revParam,
                            nullptr));
                }
                else
                {
                    SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
                }
            }
        }

        // Move pointer & reference insts to the top of the reverse-mode block.
        List<IRInst*> typeInsts;
        for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
        {
            // If the instruction is a variable allocation (or reverse-gradient pair reference), 
            // move to top.
            // TODO: This is hacky.. Need a more principled way to handle this 
            // (like primal inst hoisting)
            // 
            //if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child))
            //    nonValueInsts.add(child);
            
            // Slang doesn't support function values. So if we see a func-typed inst
            // it's proabably a reference to a function.
            // 
            switch (child->getOp())
            {
            /*
               TODO: need a better way to move specialize, lookupwitness, extractExistentialType/Value/Witness
               insts to a proper location that dominates all their use sites. Create copies of these insts
               when necessary.
                case kIROp_Specialize:
                case kIROp_LookupWitness:
                case kIROp_ExtractExistentialType:
                case kIROp_ExtractExistentialValue:
                case kIROp_ExtractExistentialWitnessTable:
            */
            case kIROp_ForwardDifferentiate:
            case kIROp_BackwardDifferentiate:
            case kIROp_BackwardDifferentiatePrimal:
            case kIROp_BackwardDifferentiatePropagate:
                typeInsts.add(child);
                break;
            }
        }

        for (auto inst : typeInsts)
        {
            inst->insertAtEnd(revBlock);
        }

        // Then, go backwards through the regular instructions, and transpose them into the new
        // rev block.
        // Note the 'reverse' traversal here.
        // 
        for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
        {
            if (as<IRDecoration>(child) || as<IRParam>(child))
                continue;
            if (as<IRType>(child))
                continue;

            if (isDifferentialInst(child))
                transposeInst(&builder, child);
        }

        // After processing the block's instructions, we 'flush' any remaining gradients 
        // in the assignments map.
        // For now, these are only function parameter gradients (or of the form IRLoad(IRParam))
        // TODO: We should be flushing *all* gradients accumulated in this block to some 
        // function scope variable, since control flow can affect what blocks contribute to
        // for a specific inst.
        // 
        List<IRLoad*> loads;
        for (const auto& [key, _] : gradientsMap)
        {
            if (auto load = as<IRLoad>(key))
                loads.add(load);
        }
        for(const auto& load : loads)
                accumulateGradientsForLoad(&builder, load);

        // Do the same thing with the phi parameters if the block.
        List<IRInst*> phiParamRevGradInsts;
        for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam())
        {
            if (isDifferentialInst(param))
            {
                // This param might be used outside this block.
                // If so, add/get an accumulator.
                // 
                if (isInstUsedOutsideParentBlock(param))
                {
                    auto accGradient = extractAccumulatorVarGradient(&builder, param);
                    addRevGradientForFwdInst(
                        param, 
                        RevGradient(param, accGradient, nullptr));
                }
                if (hasRevGradients(param))
                {
                    auto gradients = popRevGradients(param);

                    auto gradInst = emitAggregateValue(
                        &builder,
                        tryGetPrimalTypeFromDiffInst(param),
                        gradients);
                    
                    phiParamRevGradInsts.add(gradInst);
                }
                else
                { 
                    phiParamRevGradInsts.add(
                        emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
                }
            }
            else
            { 
                SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion");
            }
        }

        // Also handle any remaining gradients for insts that appear in prior blocks.
        List<IRInst*> externInsts; // Holds insts in a different block, same function.
        List<IRInst*> globalInsts; // Holds insts in the global scope.
        for (const auto& [inst, _] : gradientsMap)
        {
            auto instParent = inst->getParent();
            if (instParent != fwdBlock)
            {
                if (instParent->getParent() == fwdBlock->getParent())
                    externInsts.add(inst);
                
                if (as<IRModuleInst>(instParent))
                    globalInsts.add(inst);
            }
        }

        for (auto externInst : externInsts)
        {
            if (isNoDiffType(externInst->getDataType()))
            {
                popRevGradients(externInst);
                continue;
            }

            auto primalType = tryGetPrimalTypeFromDiffInst(externInst);
            SLANG_ASSERT(primalType);

            if (auto accVar = getOrCreateAccumulatorVar(externInst))
            {
                // Accumulate all gradients, including our accumulator variable,
                // into one inst.
                //
                auto gradients = popRevGradients(externInst);
                gradients.add(RevGradient(externInst, builder.emitLoad(accVar), nullptr));

                auto gradInst = emitAggregateValue(
                    &builder,
                    primalType,
                    gradients);
                
                builder.emitStore(accVar, gradInst);
            }
        }

        // For now, we're not going to handle global insts, and simply ignore them
        // Eventually, we want to turn these into global writes.
        // 
        for (auto globalInst : globalInsts)
        {
            if (hasRevGradients(globalInst))
                popRevGradients(globalInst);
        }

        // We _should_ be completely out of gradients to process at this point.
        SLANG_ASSERT(gradientsMap.getCount() == 0);

        // Record any phi gradients for the CFG reversal pass.
        phiGradsMap[fwdBlock] = phiParamRevGradInsts;

    }

    /// Transpose one forward-mode differential instruction into the reverse-mode
    /// code currently being emitted by `builder`.
    ///
    /// Steps:
    ///  1. Pop every reverse gradient recorded for `inst` from the gradients map.
    ///  2. Resolve the primal type those gradients are expressed against, with
    ///     special handling for returns and pair-typed loads.
    ///  3. If `inst` is used outside its parent block (and is not a load), add the
    ///     contribution held in its function-scope accumulator.
    ///  4. Aggregate everything into one reverse value and dispatch to the
    ///     opcode-specific transposer.
    ///  5. Record the gradients that transposer produces for `inst`'s operands.
    void transposeInst(IRBuilder* builder, IRInst* inst)
    {
        switch (inst->getOp())
        {
        case kIROp_ForwardDifferentiate:
            // A reference to a fwd-differentiated function has nothing to transpose
            // by itself; calls through it are transposed by transposeCall().
            return;
        default:
            break;
        }

        // Look for gradient entries for this inst.
        List<RevGradient> gradients;
        if (hasRevGradients(inst))
        {
            gradients = popRevGradients(inst);
        }

        IRType* primalType = tryGetPrimalTypeFromDiffInst(inst);

        if (!primalType)
        {
            // Special-case instructions.
            if (auto returnInst = as<IRReturn>(inst))
            {
                auto returnPairType = as<IRDifferentialPairType>(
                    tryGetPrimalTypeFromDiffInst(returnInst->getVal()));
                if (!returnPairType)
                    return;
                primalType = returnPairType->getValueType();
            }
            else if (auto loadInst = as<IRLoad>(inst))
            {
                // TODO: Unzip loads properly to avoid having to side-step this check for IRLoad
                if (auto pairType = as<IRDifferentialPairType>(loadInst->getDataType()))
                {
                    primalType = pairType->getValueType();
                }
            }
        }

        if (!primalType)
        {
            // Check for special insts for which a reverse-mode gradient doesn't apply.
            if(!as<IRStore>(inst) && !as<IRTerminatorInst>(inst))
            {
                SLANG_UNEXPECTED("Could not resolve primal type for diff inst");
            }

            // If we still can't resolve a differential type, there shouldn't 
            // be any gradients to aggregate.
            // 
            SLANG_ASSERT(gradients.getCount() == 0);
        }

        // Is this inst used in another differential block?
        // Emit a function-scope accumulator variable, and include its value.
        // Loads are excluded: they are turned into stores on a per-block basis.
        // (We should change this behaviour to treat loads like any other inst.)
        // 
        if (isInstUsedOutsideParentBlock(inst) && !as<IRLoad>(inst))
        {
            auto accGradient = extractAccumulatorVarGradient(builder, inst);
            gradients.add(
                RevGradient(inst, accGradient, nullptr));
        }
        
        // Emit the aggregate of all the gradients here. 
        // This will form the total derivative for this inst.
        auto revValue = emitAggregateValue(builder, primalType, gradients);

        auto transposeResult = transposeInst(builder, inst, revValue);

        // Carry the forward inst's name hint over to the reverse value ("<name>_T").
        // `revValue` can be null, for example when there is no primal type and
        // nothing to aggregate; transposeCall() also guards against a null
        // revValue. So only decorate an actual inst and never dereference null.
        if (revValue)
        {
            if (auto fwdNameHint = inst->findDecoration<IRNameHintDecoration>())
            {
                StringBuilder sb;
                sb << fwdNameHint->getName() << "_T";
                builder->addNameHintDecoration(revValue, sb.getUnownedSlice());
            }
        }
        
        // Add the new results to the gradients map.
        for (auto gradient : transposeResult.revPairs)
        {
            addRevGradientForFwdInst(gradient.targetInst, gradient);
        }
    }

    /// Transpose a forward-mode call `fwdCall` into a call to the corresponding
    /// backward-propagate function.
    ///
    /// Builds the propagate call's argument list from the forward call's
    /// arguments, appending `revValue` (when non-null) and the primal context
    /// (when the call carries an IRBackwardDerivativePrimalContextDecoration).
    /// The callee is either a witness-table lookup of the requirement's backward
    /// derivative (interface calls) or an IRBackwardDifferentiatePropagate inst
    /// that a later iteration resolves.
    ///
    /// Invariant: `args` and `argRequiresLoad` hold one entry per forward-call
    /// argument (with `nullptr` in `args` for dropped `out` arguments), followed by
    /// any trailing extras. `argTypes` only has entries for non-null args, so it
    /// lines up with the filtered `callArgs`.
    ///
    /// Returns the reverse gradients for forward arguments that were passed
    /// through a temporary DiffPair variable (argRequiresLoad == true).
    TranspositionResult transposeCall(IRBuilder* builder, IRCall* fwdCall, IRInst* revValue)
    {
        auto fwdDiffCallee = as<IRForwardDifferentiate>(fwdCall->getCallee());

        // If the callee is not a fwd-differentiate(fn), there are only two
        // cases. Either this is a call to something that doesn't need to be transposed,
        // or this is a user-written function calling something that isn't marked
        // with IRForwardDifferentiate but is handling differentials.
        // We currently do not handle the latter.
        // However, if we see a callee with no arguments, we can just skip over it,
        // since there's nothing to backpropagate to.
        // 
        if (!fwdDiffCallee)
        {
            if (fwdCall->getArgCount() == 0)
            {
                return TranspositionResult(List<RevGradient>());
            }
            else
            {
                SLANG_UNIMPLEMENTED_X(
                    "This case should only trigger on a user-defined fwd-mode function"
                    " calling another user-defined function not marked with __fwd_diff()");
            }
        }

        auto baseFn = fwdDiffCallee->getBaseFn();

        List<IRInst*> args;
        List<IRType*> argTypes;
        List<bool> argRequiresLoad;

        // Returns the DifferentialPairType of `type`, looking through one level of
        // pointer type; null if it is not a (pointer to a) differential pair.
        auto getDiffPairType = [](IRType* type)
        {
            if (auto ptrType = as<IRPtrTypeBase>(type))
                type = ptrType->getValueType();
            return as<IRDifferentialPairType>(type);
        };

        // After the call, the differential half of `srcTempPairVar` is copied back
        // into `destVar` (the split-out differential variable of an inout arg).
        struct DiffValWriteBack
        {
            IRInst* destVar;
            IRInst* srcTempPairVar;
        };
        List<DiffValWriteBack> writebacks;

        auto baseFnType = as<IRFuncType>(getResolvedInstForDecorations(baseFn->getDataType()));

        SLANG_RELEASE_ASSERT(baseFnType);
        SLANG_RELEASE_ASSERT(fwdCall->getArgCount() == baseFnType->getParamCount());

        for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
        {
            auto arg = fwdCall->getArg(ii);
            auto paramType = baseFnType->getParamType(ii);

            if (as<IRLoadReverseGradient>(arg))
            {
                // Original parameters that are `out DifferentiableType` will turn into
                // a `in Differential` parameter. The split logic will insert LoadReverseGradient insts
                // to inform us this case. Here we just need to generate a load of the derivative variable
                // and use it as the final argument.
                args.add(builder->emitLoad(arg->getOperand(0)));
                argTypes.add(args.getLast()->getDataType());
                argRequiresLoad.add(false);
            }
            else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg))
            {
                // An argument to an inout parameter will come in the form of a ReverseGradientDiffPairRef(primalVar, diffVar) inst
                // after splitting.
                // In order to perform the call, we need a temporary var to store the DiffPair.
                auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
                auto tempVar = builder->emitVar(pairType);
                auto primalVal = builder->emitLoad(instPair->getPrimal());

                auto diffVal = builder->emitLoad(instPair->getDiff());
                auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
                builder->emitStore(tempVar, pairVal);
                args.add(tempVar);
                argTypes.add(builder->getInOutType(pairType));
                argRequiresLoad.add(false);
                writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
            }
            else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
            {
                // Normal differentiable input parameter will become an inout DiffPair parameter
                // in the propagate func. The split logic has already prepared the initial value
                // to pass in. We need to define a temp variable with this initial value and pass
                // in the temp variable as argument to the inout parameter. 

                auto makePairArg = as<IRMakeDifferentialPair>(arg);
                SLANG_RELEASE_ASSERT(makePairArg);

                auto pairType = as<IRDifferentialPairType>(arg->getDataType());
                auto var = builder->emitVar(arg->getDataType());
                
                auto diffZero = emitDZeroOfDiffInstType(builder, pairType->getValueType());

                // Initialize this var to (arg.primal, 0).
                builder->emitStore(
                    var,
                    builder->emitMakeDifferentialPair(
                        arg->getDataType(),
                        makePairArg->getPrimalValue(),
                        diffZero));

                args.add(var);
                argTypes.add(builder->getInOutType(pairType));
                // The propagated gradient is read back out of `var` after the call.
                argRequiresLoad.add(true);
            }
            else
            {
                if (as<IROutType>(paramType))
                {
                    // `out` arguments are dropped from the propagate call. A null
                    // placeholder keeps `args` aligned with the forward args;
                    // no argTypes entry is added because the arg is filtered out below.
                    args.add(nullptr);
                    argRequiresLoad.add(false);
                }
                else if (as<IRInOutType>(paramType))
                {
                    arg = builder->emitLoad(arg);
                    args.add(arg);
                    argTypes.add(arg->getDataType());
                    argRequiresLoad.add(false);
                }
                else
                {
                    args.add(arg);
                    argTypes.add(arg->getDataType());
                    argRequiresLoad.add(false);
                }
            }
        }

        // The incoming reverse gradient of the call's result, if any.
        if (revValue)
        {
            args.add(revValue);
            argTypes.add(revValue->getDataType());
            argRequiresLoad.add(false);
        }

        // If the callee provides a primal implementation that produces continuation context for propagation phase
        // we grab it and pass it as argument to the propagation function.
        //
        if (auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
        {   
            auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar();
            
            auto contextLoad = builder->emitLoad(primalContextVar);

            args.add(contextLoad);
            argTypes.add(as<IRPtrTypeBase>(
                primalContextVar->getDataType())
                ->getValueType());
            argRequiresLoad.add(false);
        }

        auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
        IRInst* revCallee = nullptr;
        if (getResolvedInstForDecorations(baseFn)->getOp() == kIROp_LookupWitness)
        {
            // This is an interface method call, we can simply transcribe it here.
            // Look up the backward-derivative requirement on the same witness table,
            // re-applying any specialization arguments of the original callee.
            auto specialize = as<IRSpecialize>(baseFn);
            auto innerFn = baseFn;
            if (specialize)
                innerFn = specialize->getBase();
            auto lookupWitness = as<IRLookupWitnessMethod>(innerFn);
            SLANG_RELEASE_ASSERT(lookupWitness);
            auto diffDecor = lookupWitness->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>();
            SLANG_RELEASE_ASSERT(diffDecor);
            auto diffKey = diffDecor->getBackwardDerivativeFunc();
            revCallee = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookupWitness->getWitnessTable(), diffKey);
            if (specialize)
            {
                List<IRInst*> specArgs;
                for (UInt i = 0; i < specialize->getArgCount(); i++)
                    specArgs.add(specialize->getArg(i));
                revCallee = builder->emitSpecializeInst(builder->getTypeKind(), revCallee, specArgs.getCount(), specArgs.getBuffer());
            }
            revCallee->setFullType(revFnType);
        }
        else
        {
            // All other calls, we insert a `backwardDifferentiate` inst so we will process it in a follow-up iteration.
            revCallee = builder->emitBackwardDifferentiatePropagateInst(
                revFnType,
                baseFn);
        }

        // Drop the null placeholders left for `out` arguments.
        List<IRInst*> callArgs;
        for (auto arg : args)
            if (arg)
                callArgs.add(arg);
        builder->emitCallInst(revFnType->getResultType(), revCallee, callArgs);

        // Write the result gradients back to their corresponding split variables.
        for (auto wb : writebacks)
        {
            auto loadedPair = builder->emitLoad(wb.srcTempPairVar);
            auto diffType = as<IRPtrTypeBase>(wb.destVar->getDataType())->getValueType();
            auto loadedDiff = builder->emitDifferentialPairGetDifferential(diffType, loadedPair);
            builder->emitStore(wb.destVar, loadedDiff);
        }

        List<RevGradient> gradients;
        for (Index ii = 0; ii < args.getCount(); ii++)
        {
            if (!args[ii])
                continue;

            // Is this arg relevant to auto-diff?
            if (auto diffPairType = getDiffPairType(args[ii]->getDataType()))
            {
                // Only args that went through a temporary DiffPair var
                // (argRequiresLoad) produce a gradient here. For other pointer-typed
                // args, the gradient is accumulated through the pointer itself.
                // Since argRequiresLoad is only true for forward args, indexing
                // fwdCall->getArg(ii) is in range.
                // 
                if (argRequiresLoad[ii])
                {
                    auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType(
                        builder, 
                        diffPairType->getValueType());
                    gradients.add(RevGradient(
                        RevGradient::Flavor::Simple,
                        fwdCall->getArg(ii),
                        builder->emitDifferentialPairGetDifferential(
                            diffArgType, builder->emitLoad(args[ii])),
                        nullptr));
                }
            }
        }
        
        return TranspositionResult(gradients);
    }

    /// Map a differential block back to its primal counterpart.
    ///
    /// Returns the block recorded by the block's IRDifferentialInstDecoration, or
    /// nullptr when the block has no such decoration (or the recorded primal inst
    /// is not a block).
    IRBlock* getPrimalBlock(IRBlock* fwdBlock)
    {
        auto diffDecor = fwdBlock->findDecoration<IRDifferentialInstDecoration>();
        if (!diffDecor)
            return nullptr;

        IRInst* primalInst = diffDecor->getPrimalInst();
        return as<IRBlock>(primalInst);
    }

    /// Return the second block of `func`, i.e. the block following its first
    /// block. NOTE(review): presumably the first block only holds parameters,
    /// so this is the first block with real code. `func` must have a first block.
    IRBlock* getFirstCodeBlock(IRGlobalValueWithCode* func)
    {
        auto leadingBlock = func->getFirstBlock();
        auto codeBlock = leadingBlock->getNextBlock();
        return codeBlock;
    }

    /// Collect the 'terminal' primal blocks of `func`: non-differential blocks
    /// with at least one differential successor, i.e. the points where control
    /// passes from primal code into differential code.
    ///
    /// Each such block is reported exactly once, even when several of its
    /// successors are differential blocks (the previous version appended the
    /// block once per differential successor).
    List<IRBlock*> getTerminalPrimalBlocks(IRGlobalValueWithCode* func)
    {
        List<IRBlock*> terminalPrimalBlocks;
        for (auto block : func->getBlocks())
        {
            // A differential block is never a terminal *primal* block; this check
            // does not depend on the successor, so do it once per block.
            if (isDifferentialInst(block))
                continue;

            for (auto successor : block->getSuccessors())
            {
                if (isDifferentialInst(successor))
                {
                    terminalPrimalBlocks.add(block);
                    break;
                }
            }
        }

        return terminalPrimalBlocks;
    }

    /// Return the merge ("after") block of the structured control-flow construct
    /// that terminates `block`:
    ///   - ifElse -> its after-block
    ///   - switch -> its break label
    ///   - loop   -> its break block
    /// Unconditional branches and returns have no after-block (nullptr). Any
    /// other terminator is reported as unimplemented.
    IRBlock* getAfterBlock(IRBlock* block)
    {
        auto terminator = block->getTerminator();
        switch (terminator->getOp())
        {
        case kIROp_ifElse:
            return as<IRIfElse>(terminator)->getAfterBlock();

        case kIROp_Switch:
            return as<IRSwitch>(terminator)->getBreakLabel();

        case kIROp_loop:
            return as<IRLoop>(terminator)->getBreakBlock();

        case kIROp_Return:
        case kIROp_unconditionalBranch:
            return nullptr;

        default:
            SLANG_UNIMPLEMENTED_X("Unhandled terminator inst when building after-block map");
        }
    }

    /// Populate `afterBlockMap` for a forward-mode function: for every
    /// differential block whose terminator has an after-block (see
    /// getAfterBlock()), map that after-block to the terminator that owns it.
    void buildAfterBlockMap(IRGlobalValueWithCode* fwdFunc)
    {
        for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock())
        {
            // Primal blocks are not part of this map.
            if (!isDifferentialInst(block))
                continue;

            auto afterBlock = getAfterBlock(block);
            if (!afterBlock)
                continue;

            auto terminator = block->getTerminator();

            // A block can be the after-block of at most one control-flow inst.
            SLANG_ASSERT(!(afterBlockMap.containsKey(afterBlock) &&
                afterBlockMap[afterBlock] != terminator));

            afterBlockMap.set(afterBlock, terminator);
        }
    }

    /// Collect the blocks of a forward-mode function `func` that end in a return.
    ///
    /// These are treated as the terminal differential blocks. The logic is meant
    /// for fwd-mode blocks; rev-mode blocks need different handling.
    /// NOTE(review): no isDifferentialInst() check is performed, so this relies on
    /// returns appearing only in differential blocks — confirm.
    List<IRBlock*> getTerminalDiffBlocks(IRGlobalValueWithCode* func)
    {
        List<IRBlock*> returningBlocks;
        for (auto block : func->getBlocks())
        {
            auto terminator = block->getTerminator();
            if (!as<IRReturn>(terminator))
                continue;
            returningBlocks.add(block);
        }
        return returningBlocks;
    }
    
    /// True if at least one predecessor of `fwdBlock` is a differential block.
    bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock)
    {
        bool hasDiffPred = false;
        for (auto pred : fwdBlock->getPredecessors())
        {
            if (!isDifferentialInst(pred))
                continue;
            hasDiffPred = true;
            break;
        }
        return hasDiffPred;
    }

    /// Insert a new block immediately before `revBlock`. The new block only
    /// branches to `revBlock`, passing `phiArgs` as the branch arguments (phi
    /// values).
    ///
    /// The new block is marked differential when `revBlock` is. Returns the new
    /// block.
    IRBlock* insertPhiBlockBefore(IRBlock* revBlock, List<IRInst*> phiArgs)
    {
        IRBuilder localBuilder(autodiffContext->moduleInst->getModule());
        localBuilder.setInsertBefore(revBlock);

        auto trampolineBlock = localBuilder.emitBlock();

        const bool targetIsDiff = isDifferentialInst(revBlock);
        if (targetIsDiff)
            localBuilder.markInstAsDifferential(trampolineBlock);

        localBuilder.emitBranch(revBlock, phiArgs.getCount(), phiArgs.getBuffer());

        return trampolineBlock;
    }
    
    /// Per-opcode dispatcher. Transposes `fwdInst`, given the aggregated reverse
    /// gradient `revValue` flowing into it, and returns the reverse gradients it
    /// produces for its operands.
    ///
    /// Control-flow and bookkeeping insts produce no gradients here. An
    /// unrecognized opcode trips an assertion failure.
    TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        const auto opcode = fwdInst->getOp();
        switch (opcode)
        {
        // Bookkeeping and control flow: nothing to propagate.
        // transposeBlock() takes care of emitting the matching reverse branch.
        case kIROp_unconditionalBranch:
        case kIROp_conditionalBranch:
        case kIROp_ifElse:
        case kIROp_loop:
        case kIROp_Switch:
        case kIROp_LoadReverseGradient:
        case kIROp_ReverseGradientDiffPairRef:
        case kIROp_DefaultConstruct:
        case kIROp_Specialize:
        case kIROp_LookupWitness:
        case kIROp_ExtractExistentialType:
        case kIROp_ExtractExistentialWitnessTable:
            return TranspositionResult();

        // Arithmetic and conversions.
        case kIROp_Add:
        case kIROp_Sub:
        case kIROp_Mul:
        case kIROp_Div:
        case kIROp_Neg:
            return transposeArithmetic(builder, fwdInst, revValue);
        case kIROp_Select:
            return transposeSelect(builder, fwdInst, revValue);
        case kIROp_FloatCast:
            return transposeFloatCast(builder, fwdInst, revValue);

        // Calls and function exit.
        case kIROp_Call:
            return transposeCall(builder, as<IRCall>(fwdInst), revValue);
        case kIROp_Return:
            return transposeReturn(builder, as<IRReturn>(fwdInst), revValue);

        // Memory.
        case kIROp_Load:
            return transposeLoad(builder, as<IRLoad>(fwdInst), revValue);
        case kIROp_Store:
            return transposeStore(builder, as<IRStore>(fwdInst), revValue);

        // Element / field access and update.
        case kIROp_swizzle:
            return transposeSwizzle(builder, as<IRSwizzle>(fwdInst), revValue);
        case kIROp_FieldExtract:
            return transposeFieldExtract(builder, as<IRFieldExtract>(fwdInst), revValue);
        case kIROp_GetElement:
            return transposeGetElement(builder, as<IRGetElement>(fwdInst), revValue);
        case kIROp_UpdateElement:
            return transposeUpdateElement(builder, fwdInst, revValue);

        // Differential pairs (intrinsic and user-code forms).
        case kIROp_MakeDifferentialPair:
            return transposeMakePair(builder, as<IRMakeDifferentialPair>(fwdInst), revValue);
        case kIROp_DifferentialPairGetDifferential:
            return transposeGetDifferential(builder, as<IRDifferentialPairGetDifferential>(fwdInst), revValue);
        case kIROp_MakeDifferentialPairUserCode:
            return transposeMakePairUserCode(builder, as<IRMakeDifferentialPairUserCode>(fwdInst), revValue);
        case kIROp_DifferentialPairGetPrimalUserCode:
            return transposeGetPrimalUserCode(builder, as<IRDifferentialPairGetPrimalUserCode>(fwdInst), revValue);
        case kIROp_DifferentialPairGetDifferentialUserCode:
            return transposeGetDifferentialUserCode(builder, as<IRDifferentialPairGetDifferentialUserCode>(fwdInst), revValue);

        // Aggregate construction.
        case kIROp_MakeVector:
            return transposeMakeVector(builder, fwdInst, revValue);
        case kIROp_MakeVectorFromScalar:
            return transposeMakeVectorFromScalar(builder, fwdInst, revValue);
        case kIROp_MakeMatrix:
            return transposeMakeMatrix(builder, fwdInst, revValue);
        case kIROp_MakeMatrixFromScalar:
            return transposeMakeMatrixFromScalar(builder, fwdInst, revValue);
        case kIROp_MatrixReshape:
            return transposeMatrixReshape(builder, fwdInst, revValue);
        case kIROp_MakeStruct:
            return transposeMakeStruct(builder, fwdInst, revValue);
        case kIROp_MakeArray:
            return transposeMakeArray(builder, fwdInst, revValue);
        case kIROp_MakeArrayFromElement:
            return transposeMakeArrayFromElement(builder, fwdInst, revValue);

        // Existentials and any-value packing.
        case kIROp_MakeExistential:
            return transposeMakeExistential(builder, fwdInst, revValue);
        case kIROp_ExtractExistentialValue:
            return transposeExtractExistentialValue(builder, fwdInst, revValue);
        case kIROp_Reinterpret:
            return transposeReinterpret(builder, fwdInst, revValue);
        case kIROp_PackAnyValue:
            return transposePackAnyValue(builder, fwdInst, revValue);

        default:
            SLANG_ASSERT_FAILURE("Unhandled instruction");
        }
    }

    /// Transpose a forward-mode load from `fwdLoad->getPtr()`.
    ///
    /// The incoming reverse gradient `revValue` is written back into the same
    /// address. If an earlier transposed load has already written to this address
    /// (tracked in `usedPtrs`), the gradient currently stored there is added in,
    /// so repeated loads of one address accumulate instead of overwriting.
    ///
    /// For differential-pair-typed loads, only the differential half is updated
    /// and the stored primal half is preserved. The current value's differential
    /// half is extracted before aggregation, because a whole pair is not a
    /// gradient that can be summed with `revValue`.
    ///
    /// Returns no gradients; the result is materialized through the store.
    TranspositionResult transposeLoad(IRBuilder* builder, IRLoad* fwdLoad, IRInst* revValue)
    {
        auto revPtr = fwdLoad->getPtr();
        auto loadType = fwdLoad->getDataType();
        auto pairType = as<IRDifferentialPairType>(loadType);

        // Resolve the aggregation type the same way transposeInst() does for
        // loads. Pair-typed loads may not resolve via tryGetPrimalTypeFromDiffInst,
        // in which case the pair's value type is used.
        IRType* primalType = tryGetPrimalTypeFromDiffInst(fwdLoad);
        if (!primalType && pairType)
            primalType = pairType->getValueType();

        List<RevGradient> gradients(RevGradient(
            revPtr,
            revValue,
            nullptr));

        if (usedPtrs.contains(revPtr))
        {
            // Re-emit a load to get the _current_ value of revPtr.
            IRInst* revCurrGrad = builder->emitLoad(revPtr);

            // A pair-typed address holds (primal, diff); only the differential
            // half is an accumulated gradient.
            if (pairType)
            {
                revCurrGrad = builder->emitDifferentialPairGetDifferential(
                    (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType),
                    revCurrGrad);
            }

            // Add the current value to the aggregation list.
            gradients.add(RevGradient(
                revPtr,
                revCurrGrad,
                nullptr));
        }
        else
        {
            usedPtrs.add(revPtr);
        }
        
        // Get the _total_ value.
        auto aggregateGradient = emitAggregateValue(
            builder,
            primalType,
            gradients);
        
        if (pairType)
        {
            // Keep the primal half already stored at the address; replace only
            // the differential half with the aggregated gradient.
            auto primalPairVal = builder->emitLoad(revPtr);
            auto primalVal = builder->emitDifferentialPairGetPrimal(primalPairVal);

            auto pairVal = builder->emitMakeDifferentialPair(loadType, primalVal, aggregateGradient);

            builder->emitStore(revPtr, pairVal);
        }
        else
        {
            // Store this back into the pointer.
            builder->emitStore(revPtr, aggregateGradient);
        }

        return TranspositionResult(List<RevGradient>());
    }

    /// Transpose a forward-mode store `*ptr = val`.
    ///
    /// Reads the gradient accumulated at `ptr` and resets `ptr` to the zero
    /// differential. It then records the gradient read (its differential half,
    /// if the stored value is a pair) as the reverse gradient of `val`. The
    /// incoming reverse value is unused: a store has no result.
    TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*)
    {
        auto destPtr = fwdStore->getPtr();
        auto storedVal = fwdStore->getVal();

        // Gradient accumulated at the destination so far.
        IRInst* accumulated = builder->emitLoad(destPtr);

        auto primalType = tryGetPrimalTypeFromDiffInst(storedVal);
        SLANG_ASSERT(primalType);

        // The store overwrote the address, so earlier writers receive no gradient
        // through it: reset it to zero.
        builder->emitStore(destPtr, emitDZeroOfDiffInstType(builder, primalType));

        auto pairType = as<IRDifferentialPairType>(accumulated->getDataType());
        if (pairType)
        {
            auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType);
            accumulated = builder->emitDifferentialPairGetDifferential(diffType, accumulated);
        }

        List<RevGradient> result;
        result.add(RevGradient(RevGradient::Flavor::Simple, storedVal, accumulated, fwdStore));
        return TranspositionResult(result);
    }

    /// Transpose a swizzle `A = p.<components>`.
    ///
    /// dA is recorded against the swizzle's base with the Swizzle flavor, and the
    /// forward swizzle is kept so aggregation can scatter it back to the selected
    /// components (e.g. `A = p.x` contributes `float3(dA, 0, 0)` to dp).
    TranspositionResult transposeSwizzle(IRBuilder*, IRSwizzle* fwdSwizzle, IRInst* revValue)
    {
        auto swizzleBase = fwdSwizzle->getBase();
        List<RevGradient> result;
        result.add(RevGradient(RevGradient::Flavor::Swizzle, swizzleBase, revValue, fwdSwizzle));
        return TranspositionResult(result);
    }

    
    /// Transpose a field extraction `A = s.f`.
    ///
    /// dA is recorded against the struct base with the FieldExtract flavor; the
    /// forward extract is kept so aggregation knows which field receives it.
    TranspositionResult transposeFieldExtract(IRBuilder*, IRFieldExtract* fwdExtract, IRInst* revValue)
    {
        auto structBase = fwdExtract->getBase();
        List<RevGradient> result;
        result.add(RevGradient(RevGradient::Flavor::FieldExtract, structBase, revValue, fwdExtract));
        return TranspositionResult(result);
    }

    /// Transpose an element access `A = arr[i]`.
    ///
    /// dA is recorded against the container base with the GetElement flavor; the
    /// forward access is kept so aggregation knows which element receives it.
    TranspositionResult transposeGetElement(IRBuilder*, IRGetElement* fwdGetElement, IRInst* revValue)
    {
        auto containerBase = fwdGetElement->getBase();
        List<RevGradient> result;
        result.add(RevGradient(RevGradient::Flavor::GetElement, containerBase, revValue, fwdGetElement));
        return TranspositionResult(result);
    }

    /// Transpose `P = MakePair(A, dA)`.
    ///
    /// `revValue` is the reverse value of the pair's differential component only,
    /// so it flows straight to the differential operand: `dA += dP`.
    TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
    {
        auto diffOperand = fwdMakePair->getDifferentialValue();
        List<RevGradient> result;
        result.add(RevGradient(RevGradient::Flavor::Simple, diffOperand, revValue, fwdMakePair));
        return TranspositionResult(result);
    }

    /// Transpose `A = GetDiff(P)`: the reverse value flows unchanged to the pair
    /// operand (`dP.d += dA`).
    TranspositionResult transposeGetDifferential(IRBuilder*, IRDifferentialPairGetDifferential* fwdGetDiff, IRInst* revValue)
    {
        auto pairOperand = fwdGetDiff->getBase();
        List<RevGradient> result;
        result.add(RevGradient(RevGradient::Flavor::Simple, pairOperand, revValue, fwdGetDiff));
        return TranspositionResult(result);
    }

    // Transpose of a user-code MakeDifferentialPair: the incoming pair-typed gradient is split
    // into its primal and differential halves, each routed back to the matching operand.
    //
    // (P = MakePairUserCode(A, B)) -> [(dA += dP.p), (dB += dP.d)]
    TranspositionResult transposeMakePairUserCode(IRBuilder* builder, IRMakeDifferentialPairUserCode* fwdMakePair, IRInst* revValue)
    {
        auto primalOperand = fwdMakePair->getPrimalValue();
        auto differentialOperand = fwdMakePair->getDifferentialValue();

        // The primal half is emitted before the differential half.
        auto revPrimalHalf = builder->emitDifferentialPairGetPrimalUserCode(revValue);
        auto revDifferentialHalf = builder->emitDifferentialPairGetDifferentialUserCode(
            differentialOperand->getFullType(), revValue);

        List<RevGradient> gradients;
        gradients.add(RevGradient(RevGradient::Flavor::Simple, primalOperand, revPrimalHalf, fwdMakePair));
        gradients.add(RevGradient(RevGradient::Flavor::Simple, differentialOperand, revDifferentialHalf, fwdMakePair));
        return TranspositionResult(gradients);
    }

    // Transpose of reading the differential half of a user-code differential pair.
    // The gradient is tagged Flavor::DifferentialPairGetElementUserCode; the forward inst
    // (a GetDifferentialUserCode) identifies which half of the pair it belongs to.
    //
    // Reading `.d` means dA belongs in the differential slot of the pair gradient:
    // (A = x.d) -> (dX = DiffPairUserCode(0, dA))
    TranspositionResult transposeGetDifferentialUserCode(IRBuilder*, IRDifferentialPairGetDifferentialUserCode* fwdGetDiff, IRInst* revValue)
    {
        return TranspositionResult(
            List<RevGradient>(
                RevGradient(
                    RevGradient::Flavor::DifferentialPairGetElementUserCode,
                    fwdGetDiff->getBase(),
                    revValue,
                    fwdGetDiff)));
    }

    // Transpose of reading the primal half of a user-code differential pair.
    // The gradient is tagged Flavor::DifferentialPairGetElementUserCode; the forward inst
    // (a GetPrimalUserCode) identifies which half of the pair it belongs to.
    //
    // Reading `.p` means dA belongs in the primal slot of the pair gradient:
    // (A = x.p) -> (dX = DiffPairUserCode(dA, 0))
    TranspositionResult transposeGetPrimalUserCode(IRBuilder*, IRDifferentialPairGetPrimalUserCode* fwdGetPrimal, IRInst* revValue)
    {
        return TranspositionResult(
            List<RevGradient>(
                RevGradient(
                    RevGradient::Flavor::DifferentialPairGetElementUserCode,
                    fwdGetPrimal->getBase(),
                    revValue,
                    fwdGetPrimal)));
    }

    // Transpose of broadcasting a scalar into a vector: every lane of the vector gradient
    // flows back into the same scalar, so one Simple gradient is produced per lane.
    //
    // (A = MakeVectorFromScalar(s)) -> [(ds += dA[0]), (ds += dA[1]), ..., (ds += dA[N-1])]
    TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
    {
        auto revVectorType = as<IRVectorType>(revValue->getDataType());
        SLANG_RELEASE_ASSERT(revVectorType);
        auto laneCountLit = as<IRIntLit>(revVectorType->getElementCount());
        SLANG_RELEASE_ASSERT(laneCountLit);

        auto broadcastScalar = fwdMakeVector->getOperand(0);
        const UIndex laneCount = (UIndex)laneCountLit->getValue();

        List<RevGradient> gradients;
        for (UIndex lane = 0; lane < laneCount; ++lane)
        {
            auto laneIndex = builder->getIntValue(builder->getIntType(), lane);
            auto revLane = builder->emitElementExtract(revValue, laneIndex);
            gradients.add(RevGradient(RevGradient::Flavor::Simple, broadcastScalar, revLane, fwdMakeVector));
        }
        return TranspositionResult(gradients);
    }

    // Transpose of broadcasting a scalar into a matrix: every element of the matrix gradient
    // flows back into the same scalar, so one Simple gradient is produced per element.
    //
    // (M = MakeMatrixFromScalar(s)) -> [(ds += dM[r][c]) for every row r and column c]
    TranspositionResult transposeMakeMatrixFromScalar(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue)
    {
        auto matrixType = as<IRMatrixType>(revValue->getDataType());
        SLANG_RELEASE_ASSERT(matrixType);
        auto row = as<IRIntLit>(matrixType->getRowCount());
        auto col = as<IRIntLit>(matrixType->getColumnCount());
        SLANG_RELEASE_ASSERT(row && col);

        auto scalarOperand = fwdMakeMatrix->getOperand(0);
        const UIndex rowCount = (UIndex)row->getValue();
        const UIndex colCount = (UIndex)col->getValue();

        List<RevGradient> gradients;
        for (UIndex r = 0; r < rowCount; r++)
        {
            // Extract each row of the gradient once, then read all of its columns from it.
            auto revRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r));
            for (UIndex c = 0; c < colCount; c++)
            {
                auto revElement = builder->emitElementExtract(revRow, builder->getIntValue(builder->getIntType(), c));
                gradients.add(RevGradient(
                    RevGradient::Flavor::Simple,
                    scalarOperand,
                    revElement,
                    fwdMakeMatrix));
            }
        }
        return TranspositionResult(gradients);
    }

    // Transpose of MakeMatrix. The forward matrix is built either from one vector per row or
    // from a flat list of scalars in row-major order (all columns of row 0, then row 1, ...).
    // The incoming matrix gradient is split back into the matching per-operand pieces.
    //
    // (M = MakeMatrix(R0, R1, ...))        -> [(dR0 += dM[0]), (dR1 += dM[1]), ...]
    // (M = MakeMatrix(m00, m01, ..., mRC)) -> [(dm00 += dM[0][0]), (dm01 += dM[0][1]), ...]
    TranspositionResult transposeMakeMatrix(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue)
    {
        List<RevGradient> gradients;
        auto matrixType = as<IRMatrixType>(fwdMakeMatrix->getDataType());
        SLANG_RELEASE_ASSERT(matrixType);
        auto colCount = matrixType->getColumnCount();
        IRType* rowVectorType = nullptr;
        for (UIndex ii = 0; ii < fwdMakeMatrix->getOperandCount(); ii++)
        {
            auto argOperand = fwdMakeMatrix->getOperand(ii);
            IRInst* gradAtIndex = nullptr;
            if (as<IRVectorType>(argOperand->getDataType()))
            {
                // Vector operand `ii` is row `ii` of the matrix.
                gradAtIndex = builder->emitElementExtract(
                    argOperand->getDataType(),
                    revValue,
                    builder->getIntValue(builder->getIntType(), ii));
            }
            else
            {
                // Scalar operand: each row holds `colCount` consecutive scalars, so the flat
                // operand index must be split by the column count. (Splitting by the row
                // count only coincides with this for square matrices.)
                auto col = as<IRIntLit>(colCount);
                SLANG_RELEASE_ASSERT(col && col->getValue() > 0);
                UInt rowIndex = ii / (UInt)col->getValue();
                UInt colIndex = ii % (UInt)col->getValue();
                if (!rowVectorType)
                    rowVectorType = builder->getVectorType(matrixType->getElementType(), colCount);
                auto revRow = builder->emitElementExtract(
                    rowVectorType,
                    revValue,
                    builder->getIntValue(builder->getIntType(), rowIndex));
                gradAtIndex = builder->emitElementExtract(
                    matrixType->getElementType(),
                    revRow,
                    builder->getIntValue(builder->getIntType(), colIndex));
            }
            gradients.add(RevGradient(
                RevGradient::Flavor::Simple,
                argOperand,
                gradAtIndex,
                fwdMakeMatrix));
        }
        return TranspositionResult(gradients);
    }

    // Transpose of a matrix reshape. The gradient for the source matrix copies the overlapping
    // top-left region of the result gradient. Source elements outside the result's extent get
    // zero, built with getFloatValue. NOTE(review): this assumes a floating-point element type.
    //
    // (B = reshape<RxC>(A)) -> (dA += pad-or-crop(dB) to A's shape)
    TranspositionResult transposeMatrixReshape(IRBuilder* builder, IRInst* fwdMatrixReshape, IRInst* revValue)
    {
        auto sourceMatrix = fwdMatrixReshape->getOperand(0);
        auto sourceType = as<IRMatrixType>(sourceMatrix->getDataType());
        SLANG_RELEASE_ASSERT(sourceType);

        auto sourceRowLit = as<IRIntLit>(sourceType->getRowCount());
        auto sourceColLit = as<IRIntLit>(sourceType->getColumnCount());
        SLANG_RELEASE_ASSERT(sourceRowLit && sourceColLit);

        auto resultType = as<IRMatrixType>(revValue->getDataType());
        SLANG_RELEASE_ASSERT(resultType);
        auto resultRowLit = as<IRIntLit>(resultType->getRowCount());
        auto resultColLit = as<IRIntLit>(resultType->getColumnCount());
        SLANG_RELEASE_ASSERT(resultRowLit && resultColLit);

        const IRIntegerValue sourceRows = sourceRowLit->getValue();
        const IRIntegerValue sourceCols = sourceColLit->getValue();
        const IRIntegerValue resultRows = resultRowLit->getValue();
        const IRIntegerValue resultCols = resultColLit->getValue();

        // Lazily created, shared zero used for every padded element.
        IRInst* zeroElement = nullptr;
        auto getZeroElement = [&]() -> IRInst*
        {
            if (!zeroElement)
                zeroElement = builder->getFloatValue(sourceType->getElementType(), 0.0f);
            return zeroElement;
        };

        // Elements are collected row by row, matching emitMakeMatrix's flat operand order.
        List<IRInst*> flatElements;
        for (IRIntegerValue r = 0; r < sourceRows; r++)
        {
            const bool rowOverlaps = r < resultRows;
            IRInst* revRowVector = rowOverlaps
                ? builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r))
                : nullptr;

            for (IRIntegerValue c = 0; c < sourceCols; c++)
            {
                if (rowOverlaps && c < resultCols)
                    flatElements.add(builder->emitElementExtract(revRowVector, builder->getIntValue(builder->getIntType(), c)));
                else
                    flatElements.add(getZeroElement());
            }
        }

        auto sourceGradient = builder->emitMakeMatrix(sourceType, (UInt)flatElements.getCount(), flatElements.getBuffer());
        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, sourceMatrix, sourceGradient, fwdMatrixReshape)));
    }

    // Transpose of MakeVector. Each operand (a scalar or a smaller vector) occupies a
    // contiguous run of lanes in the result, so its gradient is the matching slice of dA:
    // a single element for one-lane operands, otherwise a swizzle of the lanes it covers.
    //
    // (A = float3(X, Y, Z))  -> [(dX += dA.x), (dY += dA.y), (dZ += dA.z)]
    // (A = float3(V.xy, Z))  -> [(dV += dA.xy), (dZ += dA.z)]
    TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
    {
        // Number of result lanes contributed by one operand.
        auto laneCountOf = [](IRInst* operand) -> UInt
        {
            auto operandVectorType = as<IRVectorType>(operand->getDataType());
            if (!operandVectorType)
                return 1;
            auto countLit = as<IRIntLit>(operandVectorType->getElementCount());
            SLANG_RELEASE_ASSERT(countLit);
            return (UInt)countLit->getValue();
        };

        List<RevGradient> gradients;
        UInt firstLane = 0;
        for (UIndex operandIndex = 0; operandIndex < fwdMakeVector->getOperandCount(); operandIndex++)
        {
            auto operand = fwdMakeVector->getOperand(operandIndex);
            auto operandType = operand->getDataType();
            const UInt laneCount = laneCountOf(operand);

            IRInst* operandGradient = nullptr;
            if (laneCount != 1)
            {
                ShortList<UInt> coveredLanes;
                for (UInt lane = firstLane; lane < firstLane + laneCount; lane++)
                    coveredLanes.add(lane);
                operandGradient = builder->emitSwizzle(
                    operandType,
                    revValue,
                    laneCount,
                    coveredLanes.getArrayView().getBuffer());
            }
            else
            {
                operandGradient = builder->emitElementExtract(
                    operandType,
                    revValue,
                    builder->getIntValue(builder->getIntType(), firstLane));
            }

            gradients.add(RevGradient(RevGradient::Flavor::Simple, operand, operandGradient, fwdMakeVector));
            firstLane += laneCount;
        }
        return TranspositionResult(gradients);
    }

    // Transpose of MakeStruct. The struct's fields and the inst's operands correspond in
    // declaration order, so the gradient for operand i is field i of dA.
    //
    // (A = MakeStruct(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
    TranspositionResult transposeMakeStruct(IRBuilder* builder, IRInst* fwdMakeStruct, IRInst* revValue)
    {
        auto structType = cast<IRStructType>(fwdMakeStruct->getFullType());

        List<RevGradient> gradients;
        UInt operandIndex = 0;
        for (auto field : structType->getFields())
        {
            auto revField = builder->emitFieldExtract(field->getFieldType(), revValue, field->getKey());
            SLANG_RELEASE_ASSERT(operandIndex < fwdMakeStruct->getOperandCount());
            auto fieldOperand = fwdMakeStruct->getOperand(operandIndex);
            gradients.add(RevGradient(RevGradient::Flavor::Simple, fieldOperand, revField, fwdMakeStruct));
            ++operandIndex;
        }
        return TranspositionResult(gradients);
    }

    // Transpose of MakeArray: the gradient of operand i is element i of the array gradient.
    //
    // (A = MakeArray(E0, E1, E2)) -> [(dE0 += dA[0]), (dE1 += dA[1]), (dE2 += dA[2])]
    TranspositionResult transposeMakeArray(IRBuilder* builder, IRInst* fwdMakeArray, IRInst* revValue)
    {
        auto arrayType = cast<IRArrayType>(fwdMakeArray->getFullType());
        auto elementType = arrayType->getElementType();
        const UInt operandCount = fwdMakeArray->getOperandCount();

        List<RevGradient> gradients;
        for (UInt index = 0; index < operandCount; index++)
        {
            auto revElement = builder->emitElementExtract(
                elementType,
                revValue,
                builder->getIntValue(builder->getIntType(), index));
            gradients.add(RevGradient(
                RevGradient::Flavor::Simple,
                fwdMakeArray->getOperand(index),
                revElement,
                fwdMakeArray));
        }
        return TranspositionResult(gradients);
    }

    // Transpose of MakeArrayFromElement (an array whose every element is the same value E).
    // Every element of the array gradient flows back into E, so one Simple gradient per
    // array element is produced, all targeting E.
    //
    // (A = MakeArrayFromElement(E)) -> [(dE += dA[0]), (dE += dA[1]), ..., (dE += dA[N-1])]
    TranspositionResult transposeMakeArrayFromElement(IRBuilder* builder, IRInst* fwdMakeArrayFromElement, IRInst* revValue)
    {
        List<RevGradient> gradients;
        auto arrayType = cast<IRArrayType>(fwdMakeArrayFromElement->getFullType());
        // Use `as` (which yields null on a mismatch, as relied on throughout this file) so the
        // release assert below can catch a non-literal array size instead of reading
        // getValue() from an inst that is not an IRIntLit.
        auto arraySize = as<IRIntLit>(arrayType->getElementCount());
        SLANG_RELEASE_ASSERT(arraySize);
        // TODO: if arraySize is a generic value, we can't statically expand things here.
        // In that case we probably need another opcode e.g. `Sum(arrayValue)` that can be expanded
        // later in the pipeline when `arrayValue` becomes a known value.
        auto elementOperand = fwdMakeArrayFromElement->getOperand(0);
        for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++)
        {
            auto gradAtElement = builder->emitElementExtract(
                arrayType->getElementType(),
                revValue,
                builder->getIntValue(builder->getIntType(), ii));
            gradients.add(RevGradient(
                RevGradient::Flavor::Simple,
                elementOperand,
                gradAtElement,
                fwdMakeArrayFromElement));
        }

        return TranspositionResult(gradients);
    }

    // Transpose of UpdateElement (a copy of an aggregate with one element replaced).
    //  - The inserted value receives the slice of dA at the updated location.
    //  - The old aggregate receives dA with that location overwritten by a differential zero,
    //    because the old element did not survive into the result.
    //
    // (A = UpdateElement(arr, index, V)) -> [(dV += dA[index]), (d_arr += UpdateElement(dA, index, 0))]
    TranspositionResult transposeUpdateElement(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue)
    {
        auto updateInst = as<IRUpdateElement>(fwdUpdate);
        auto accessChain = updateInst->getAccessChain();

        auto revInsertedValue = builder->emitElementExtract(revValue, accessChain.getArrayView());

        // The zero's type comes from the primal element type recorded on the inst.
        auto primalElementTypeDecor = updateInst->findDecoration<IRPrimalElementTypeDecoration>();
        SLANG_RELEASE_ASSERT(primalElementTypeDecor);
        auto zeroForSlot = emitDZeroOfDiffInstType(builder, (IRType*)primalElementTypeDecor->getPrimalElementType());
        SLANG_ASSERT(zeroForSlot);
        auto revOldAggregate = builder->emitUpdateElement(revValue, accessChain, zeroForSlot);

        List<RevGradient> gradients;
        gradients.add(RevGradient(RevGradient::Flavor::Simple, updateInst->getElementValue(), revInsertedValue, fwdUpdate));
        gradients.add(RevGradient(RevGradient::Flavor::Simple, updateInst->getOldValue(), revOldAggregate, fwdUpdate));
        return TranspositionResult(gradients);
    }

    // Transpose of a floating-point conversion: the gradient is converted back to the
    // source operand's type.
    //
    // (A = cast<T, U>(B)) -> (dB += cast<U, T>(dA))
    TranspositionResult transposeFloatCast(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        auto castSource = fwdInst->getOperand(0);
        auto revInSourceType = builder->emitIntrinsicInst(
            castSource->getDataType(),
            kIROp_FloatCast,
            1,
            &revValue);
        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, castSource, revInSourceType, fwdInst)));
    }

    // Transpose of MakeExistential: recovers the wrapped differential value from the
    // existential gradient `revValue`.
    //  - If the wrapped value's type is itself existential-derived (ExtractExistentialType or
    //    LookupWitness), the value is extracted directly at that type.
    //  - Otherwise the type is concrete: the value is extracted at the gradient's own
    //    existential type and then reinterpreted as the concrete type.
    //
    // (A:IDiff = MakeExistential(B, W)) -> (dB: T += ExtractExistentialValue(dA))
    // (A:IDiff = MakeExistential(B, W)) -> (dB: T += Reinterpret(ExtractExistentialValue(dA)))
    TranspositionResult transposeMakeExistential(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        auto wrappedValue = fwdInst->getOperand(0);
        auto diffType = wrappedValue->getDataType();

        bool diffTypeIsExistential = false;
        switch (diffType->getOp())
        {
        case kIROp_ExtractExistentialType:
        case kIROp_LookupWitness:
            diffTypeIsExistential = true;
            break;
        default:
            break;
        }

        IRInst* revWrapped = nullptr;
        if (diffTypeIsExistential)
        {
            revWrapped = builder->emitExtractExistentialValue(diffType, revValue);
        }
        else
        {
            auto revDynamicType = builder->emitExtractExistentialType(revValue);
            auto revDynamicValue = builder->emitExtractExistentialValue(revDynamicType, revValue);
            revWrapped = builder->emitReinterpret(diffType, revDynamicValue);
        }

        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, wrappedValue, revWrapped, fwdInst)));
    }

    // Transpose of ExtractExistentialValue: the gradient is re-wrapped into the base
    // existential's type, using the differentiable witness of the extracted value's primal
    // type. Reaching this point implies revValue is of a differentiable type.
    //
    // (dA = ExtractExistentialValue(dB)) -> (dB += MakeExistential(dA, DiffWitness(PrimalType(A))))
    TranspositionResult transposeExtractExistentialValue(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
        SLANG_ASSERT(primalType);

        auto diffWitness = diffTypeContext.tryGetDifferentiableWitness(builder, primalType);
        SLANG_ASSERT(diffWitness);

        auto sourceExistential = fwdInst->getOperand(0);
        auto rewrapped = builder->emitMakeExistential(
            sourceExistential->getDataType(),
            revValue,
            diffWitness);

        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, sourceExistential, rewrapped, fwdInst)));
    }

    // Transpose of a reinterpret: the gradient is reinterpreted back to the source type.
    //
    // (A = reinterpret<T, U>(B)) -> (dB += reinterpret<U, T>(dA))
    TranspositionResult transposeReinterpret(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        auto reinterpretSource = fwdInst->getOperand(0);
        auto revInSourceType = builder->emitReinterpret(reinterpretSource->getDataType(), revValue);
        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, reinterpretSource, revInSourceType, fwdInst)));
    }

    
    // Transpose of packing into an AnyValue: the gradient is unpacked back to the source type.
    //
    // (A = packAnyValue<T, U>(B)) -> (dB += unpackAnyValue<U, T>(dA))
    TranspositionResult transposePackAnyValue(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        auto packedSource = fwdInst->getOperand(0);
        auto revUnpacked = builder->emitUnpackAnyValue(packedSource->getDataType(), revValue);
        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, packedSource, revUnpacked, fwdInst)));
    }

    // Handles a reverse-mode Load by delegating to transposeInst, which is presumably where the
    // load's reverse-mode gradients are gathered, aggregated and stored into the pointer.
    void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
    {
        transposeInst(builder, revLoad);
    }

    // Transpose of a return. Only pair-typed return values carry a gradient. Even then only the
    // differential part matters, so the gradient is passed through unchanged as a Simple
    // gradient. Any other return produces no gradients.
    //
    // (return A) -> (dA += dOut)  when A is a differential pair
    // (return A) -> (empty)       otherwise
    TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue)
    {
        auto returnedValue = fwdReturn->getVal();

        // TODO: This check needs to be changed to something like: isRelevantDifferentialPair()
        if (!as<IRDifferentialPairType>(returnedValue->getDataType()))
            return TranspositionResult();

        return TranspositionResult(List<RevGradient>(
            RevGradient(RevGradient::Flavor::Simple, returnedValue, revValue, fwdReturn)));
    }

    // Promotes a scalar `inst` to `targetType` by broadcasting when the target is a vector or
    // matrix type. Any other target type returns `inst` unchanged.
    //
    // The assertions check the operand's own type (with attributes unwrapped, as the caller
    // promoteOperandsToTargetType does). Checking the type of that type would never see a
    // vector or matrix, so the assertions could never fire.
    IRInst* promoteToType(IRBuilder* builder, IRType* targetType, IRInst* inst)
    {
        switch (targetType->getOp())
        {

        case kIROp_VectorType:
        {
            // Only scalar -> vector promotion is supported.
            auto currentType = unwrapAttributedType(inst->getDataType());
            SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType));

            return builder->emitMakeVectorFromScalar(targetType, inst);
        }

        case kIROp_MatrixType:
        {
            // Only scalar -> matrix promotion is supported.
            auto currentType = unwrapAttributedType(inst->getDataType());
            SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType) &&
                !as<IRMatrixType>(currentType));

            return builder->emitMakeMatrixFromScalar(targetType, inst);
        }

        default:
            // Default is not to promote.
            return inst;
        }
    }

    // Positions `builder` to insert code "after" `inst`, with one special case.
    // If `inst` lives in the function's first block, or in the second block (the block the
    // first one unconditionally branches to), and the second block also ends in an
    // unconditional branch, insertion moves into that third block instead: after its first
    // ordinary inst, or into the block if it has none.
    // Every other case falls back to setInsertAfterOrdinaryInst.
    void safeSetInsertAfterInst(IRBuilder* builder, IRInst* inst)
    {
        auto tryRedirectToThirdBlock = [&]() -> bool
        {
            auto ownerBlock = as<IRBlock>(inst->getParent());
            if (!ownerBlock)
                return false;

            auto entryBlock = cast<IRFunc>(ownerBlock->getParent())->getFirstBlock();
            auto entryBranch = as<IRUnconditionalBranch>(entryBlock->getTerminator());
            if (!entryBranch)
                return false;

            auto secondBlock = entryBranch->getTargetBlock();
            if (ownerBlock != entryBlock && ownerBlock != secondBlock)
                return false;

            auto secondBranch = as<IRUnconditionalBranch>(secondBlock->getTerminator());
            if (!secondBranch)
                return false;

            auto thirdBlock = secondBranch->getTargetBlock();
            if (auto firstOrdinary = thirdBlock->getFirstOrdinaryInst())
                builder->setInsertAfter(firstOrdinary);
            else
                builder->setInsertInto(thirdBlock);
            return true;
        };

        if (!tryRedirectToThirdBlock())
            setInsertAfterOrdinaryInst(builder, inst);
    }

    // Ensures every operand of `fwdInst` has the inst's own result type.
    // Each mismatching operand is promoted via promoteToType, emitted just after that operand
    // (through safeSetInsertAfterInst) so the original operands remain available.
    // If anything was promoted, a replacement inst with the same opcode and the promoted
    // operands is emitted right after `fwdInst` and returned; otherwise `fwdInst` itself is
    // returned.
    // Differential-ness is propagated to the new insts. The builder's insert location is
    // restored before returning.
    IRInst* promoteOperandsToTargetType(IRBuilder* builder, IRInst* fwdInst)
    {
        auto savedInsertLoc = builder->getInsertLoc();
        IRType* resultType = fwdInst->getDataType();

        List<IRInst*> finalOperands;
        bool promotedAny = false;

        for (UIndex operandIndex = 0; operandIndex < fwdInst->getOperandCount(); operandIndex++)
        {
            auto operand = fwdInst->getOperand(operandIndex);
            if (unwrapAttributedType(operand->getDataType()) == resultType)
            {
                finalOperands.add(operand);
                continue;
            }

            safeSetInsertAfterInst(builder, operand);
            IRInst* promoted = promoteToType(builder, resultType, operand);
            if (isDifferentialInst(operand))
                builder->markInstAsDifferential(promoted, tryGetPrimalTypeFromDiffInst(fwdInst));

            finalOperands.add(promoted);
            promotedAny = true;
        }

        if (!promotedAny)
        {
            builder->setInsertLoc(savedInsertLoc);
            return fwdInst;
        }

        builder->setInsertAfter(fwdInst);
        IRInst* replacement = builder->emitIntrinsicInst(
            fwdInst->getDataType(),
            fwdInst->getOp(),
            finalOperands.getCount(),
            finalOperands.getBuffer());
        builder->setInsertLoc(savedInsertLoc);

        if (isDifferentialInst(fwdInst))
            builder->markInstAsDifferential(replacement, tryGetPrimalTypeFromDiffInst(fwdInst));

        return replacement;
    }

    
    // Transpose of Select(cond, L, R). The gradient goes to whichever branch was selected;
    // the other branch receives a differential zero.
    //
    // (A = cond ? L : R) -> [(dL += cond ? dA : 0), (dR += cond ? 0 : dA)]
    TranspositionResult transposeSelect(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        auto condition = fwdInst->getOperand(0);
        auto trueOperand = fwdInst->getOperand(1);
        auto falseOperand = fwdInst->getOperand(2);

        auto trueZero = emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(trueOperand));
        IRInst* trueSelectArgs[] = { condition, revValue, trueZero };
        auto revTrue = builder->emitIntrinsicInst(
            trueOperand->getDataType(),
            kIROp_Select,
            3,
            trueSelectArgs);

        auto falseZero = emitDZeroOfDiffInstType(builder, tryGetPrimalTypeFromDiffInst(falseOperand));
        IRInst* falseSelectArgs[] = { condition, falseZero, revValue };
        auto revFalse = builder->emitIntrinsicInst(
            falseOperand->getDataType(),
            kIROp_Select,
            3,
            falseSelectArgs);

        List<RevGradient> gradients;
        gradients.add(RevGradient(trueOperand, revTrue, fwdInst));
        gradients.add(RevGradient(falseOperand, revFalse, fwdInst));
        return TranspositionResult(gradients);
    }

    // Transposes a linear arithmetic inst (Add, Sub, Mul, Div, Neg). `revValue` is the
    // gradient of the inst's result; gradients are produced for its differential operand(s).
    //
    // Only arithmetic on uniform types is handled directly. Mismatching operands are first
    // promoted to the result type (scalar -> vector/matrix) via promoteOperandsToTargetType.
    // That may create a new inst in place of the old one. Since we're at the transposition
    // step for the old inst and already have its aggregate gradient, there's no need to worry
    // about the 'gradientsMap' being out-of-date.
    // TODO: There are some opportunities for optimization here (otherwise we might be
    // increasing the intermediate data size unnecessarily).
    //
    // Each failure path asserts and then breaks out of the switch. If the assertion does not
    // abort in a given build, control therefore neither falls through into the next case nor
    // off the end of this non-void function.
    TranspositionResult transposeArithmetic(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
    {
        fwdInst = promoteOperandsToTargetType(builder, fwdInst);

        auto operandType = fwdInst->getOperand(0)->getDataType();

        switch(fwdInst->getOp())
        {
            case kIROp_Add:
            {
                // (Out = dA + dB) -> [(dA += dOut), (dB += dOut)]
                return TranspositionResult(
                        List<RevGradient>(
                            RevGradient(
                                fwdInst->getOperand(0),
                                revValue,
                                fwdInst),
                            RevGradient(
                                fwdInst->getOperand(1),
                                revValue,
                                fwdInst)));
            }
            case kIROp_Sub:
            {
                // (Out = dA - dB) -> [(dA += dOut), (dB -= dOut)]
                return TranspositionResult(
                        List<RevGradient>(
                            RevGradient(
                                fwdInst->getOperand(0),
                                revValue,
                                fwdInst),
                            RevGradient(
                                fwdInst->getOperand(1),
                                builder->emitNeg(
                                    revValue->getDataType(), revValue),
                                fwdInst)));
            }
            case kIROp_Mul:
            {
                if (isDifferentialInst(fwdInst->getOperand(0)))
                {
                    // (Out = dA * B) -> (dA += B * dOut)
                    return TranspositionResult(
                        List<RevGradient>(
                            RevGradient(
                                fwdInst->getOperand(0),
                                builder->emitMul(operandType, fwdInst->getOperand(1), revValue),
                                fwdInst)));
                }
                else if (isDifferentialInst(fwdInst->getOperand(1)))
                {
                    // (Out = A * dB) -> (dB += A * dOut)
                    return TranspositionResult(
                        List<RevGradient>(
                            RevGradient(
                                fwdInst->getOperand(1),
                                builder->emitMul(operandType, fwdInst->getOperand(0), revValue),
                                fwdInst)));
                }
                else
                {
                    SLANG_ASSERT_FAILURE("Neither operand of a mul instruction is a differential inst");
                }
                break;
            }
            case kIROp_Div:
            {
                if (isDifferentialInst(fwdInst->getOperand(0)))
                {
                    SLANG_RELEASE_ASSERT(!isDifferentialInst(fwdInst->getOperand(1)));

                    // (Out = dA / B) -> (dA += dOut / B)
                    return TranspositionResult(
                        List<RevGradient>(
                            RevGradient(
                                fwdInst->getOperand(0),
                                builder->emitDiv(operandType, revValue, fwdInst->getOperand(1)),
                                fwdInst)));
                }
                else
                {
                    SLANG_ASSERT_FAILURE("The first operand of a div inst must be a differential inst");
                }
                break;
            }
            case kIROp_Neg:
            {
                if (isDifferentialInst(fwdInst->getOperand(0)))
                {
                    // (Out = -dA) -> (dA += -dOut)
                    return TranspositionResult(
                        List<RevGradient>(
                            RevGradient(
                                fwdInst->getOperand(0),
                                builder->emitNeg(operandType, revValue),
                                fwdInst)));
                }
                else
                {
                    SLANG_ASSERT_FAILURE("Cannot transpose neg of a non-differentiable inst");
                }
                break;
            }

            default:
                SLANG_ASSERT_FAILURE("Unhandled arithmetic");
                break;
        }

        // Only reachable if one of the assertion failures above returns to the caller.
        return TranspositionResult();
    }

    // Combines all Swizzle-flavored gradients that target the same base value into a single
    // Simple gradient for that base.
    //
    // Each entry of `gradients` carries the forward swizzle (`fwdGradInst`) and the gradient of
    // that swizzle's result (`revGradInst`). Per-lane contributions are scattered back to the
    // base lanes the swizzle read from and summed. Lanes no swizzle touched stay at a
    // differential zero.
    //
    // Parameters:
    //   aggPrimalType - primal type of the swizzle base (a vector type, or the scalar type).
    //   gradients     - non-empty list of Swizzle gradients that all share the same target.
    // Returns a Simple gradient for the shared target, with no forward inst attached.
    RevGradient materializeSwizzleGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        SLANG_ASSERT(gradients.getCount() > 0);

        auto firstGradient = gradients[0];
        auto firstFwdSwizzleInst = as<IRSwizzle>(firstGradient.fwdGradInst);
        SLANG_RELEASE_ASSERT(firstFwdSwizzleInst);

        auto baseType = firstFwdSwizzleInst->getBase()->getDataType();

        IRIntegerValue elementCount = 0;
        IRType* elementType = nullptr;
        IRType* primalElementType = nullptr;
        bool isVectorType = false;

        if (auto vectorType = as<IRVectorType>(baseType))
        {
            auto primalVectorType = as<IRVectorType>(aggPrimalType);
            SLANG_RELEASE_ASSERT(primalVectorType);
            auto elementCountLit = as<IRIntLit>(vectorType->getElementCount());
            SLANG_RELEASE_ASSERT(elementCountLit);

            elementType = vectorType->getElementType();
            primalElementType = primalVectorType->getElementType();
            elementCount = elementCountLit->getValue();
            isVectorType = true;
        }
        else if (auto basicType = as<IRBasicType>(baseType))
        {
            elementType = basicType;
            primalElementType = aggPrimalType;
            elementCount = 1;
        }
        else
        {
            SLANG_UNREACHABLE("unknown operand type of swizzle.");
        }

        IRInst* targetInst = firstGradient.targetInst;

        // Start every base lane at zero.
        auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementType);

        List<IRInst*> elementGrads;
        for (Index i = 0; i < elementCount; ++i)
            elementGrads.add(zeroElement);

        // Adds `grad` into base lane `i`. The first contribution replaces the zero directly
        // instead of emitting an add.
        auto accGrad = [&](UIndex i, IRInst* grad)
        {
            if (elementGrads[i] == zeroElement)
                elementGrads[i] = grad;
            else
                elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementType, elementGrads[i], grad);
        };

        for (auto gradient : gradients)
        {
            SLANG_ASSERT(gradient.targetInst == targetInst);

            auto fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst);
            SLANG_RELEASE_ASSERT(fwdSwizzleInst);
            SLANG_ASSERT(fwdSwizzleInst->getBase() == firstFwdSwizzleInst->getBase());

            const UIndex swizzleCount = (UIndex)fwdSwizzleInst->getElementCount();
            for (UIndex sourceIndex = 0; sourceIndex < swizzleCount; sourceIndex++)
            {
                auto targetIndexLit = as<IRIntLit>(fwdSwizzleInst->getElementIndex(sourceIndex));
                SLANG_RELEASE_ASSERT(targetIndexLit);
                auto targetIndex = (UIndex)targetIndexLit->getValue();
                SLANG_ASSERT(targetIndex < (UIndex)elementCount);

                if (swizzleCount == 1)
                {
                    // Single-lane swizzle: its gradient is already a scalar.
                    accGrad(targetIndex, gradient.revGradInst);
                }
                else
                {
                    // Multi-lane swizzle of a vector or of a scalar (e.g. `s.xxx`): its gradient
                    // is a vector, so pick out lane `sourceIndex` before accumulating.
                    accGrad(targetIndex,
                        builder->emitElementExtract(
                            elementType,
                            gradient.revGradInst,
                            builder->getIntValue(
                                builder->getIntType(),
                                sourceIndex)));
                }
            }
        }

        if (isVectorType)
            return RevGradient(
                targetInst,
                builder->emitMakeVector(baseType, (UInt)elementCount, elementGrads.getBuffer()),
                nullptr);
        else
            return RevGradient(
                targetInst,
                elementGrads[0],
                nullptr);
    }

    // Materializes gradients whose forward-mode instruction extracts one half of a
    // user-code differential pair.
    //
    // For a `getDifferential` read, the reverse gradient becomes the differential slot of a
    // new pair whose primal slot is a zero; for a `getPrimal` read, it becomes the primal
    // slot and the differential slot is a zero. The resulting pair-valued gradients are
    // then summed by `materializeSimpleGradients`.
    //
    // Any other forward-mode instruction is reported as an unexpected error: silently
    // skipping it would drop a gradient contribution, and skipping all of them would hand
    // an empty list to `materializeSimpleGradients`, which reads `gradients[0]`.
    RevGradient materializeDifferentialPairUserCodeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        SLANG_ASSERT(gradients.getCount() > 0);

        List<RevGradient> simpleGradients;

        for (auto gradient : gradients)
        {
            // Peek at the fwd-mode get element inst to see what type we need to materialize.
            if (auto fwdGetDiff = as<IRDifferentialPairGetDifferentialUserCode>(gradient.fwdGradInst))
            {
                auto baseType = as<IRDifferentialPairUserCodeType>(diffTypeContext.getDifferentialForType(
                    builder,
                    fwdGetDiff->getBase()->getDataType()));
                SLANG_ASSERT(baseType);

                // Gradient flows into the differential half; the primal half is zero.
                simpleGradients.add(
                    RevGradient(
                        gradient.targetInst,
                        builder->emitMakeDifferentialPairUserCode(baseType, emitDZeroOfDiffInstType(builder, baseType->getValueType()), gradient.revGradInst),
                        gradient.fwdGradInst));
            }
            else if (auto fwdGetPrimal = as<IRDifferentialPairGetPrimalUserCode>(gradient.fwdGradInst))
            {
                auto baseType = as<IRDifferentialPairUserCodeType>(diffTypeContext.getDifferentialForType(
                    builder,
                    fwdGetPrimal->getBase()->getDataType()));
                SLANG_ASSERT(baseType);

                // Gradient flows into the primal half; the differential half is zero.
                simpleGradients.add(
                    RevGradient(
                        gradient.targetInst,
                        builder->emitMakeDifferentialPairUserCode(baseType, gradient.revGradInst, emitDZeroOfDiffInstType(builder, fwdGetPrimal->getFullType())),
                        gradient.fwdGradInst));
            }
            else
            {
                // Previously such gradients were dropped without notice.
                SLANG_UNEXPECTED("unexpected fwd inst for differential pair user-code get-element gradient");
            }
        }

        return materializeSimpleGradients(builder, aggPrimalType, simpleGradients);
    }

    // Materializes a non-empty set of gradients that all share one flavor into a single
    // gradient. The flavor of the first gradient is taken as the flavor of the whole set,
    // so callers must group gradients by flavor first. The caller (`emitAggregateValue`)
    // expects a Simple-flavored result.
    //
    // Flavors without a materializer (e.g. `GetDifferential`, `Invalid`) are reported via
    // SLANG_UNEXPECTED after the switch. Previously the `default` case only issued an
    // assertion and could then fall off the end of this value-returning function
    // (undefined behavior).
    RevGradient materializeGradientSet(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        SLANG_ASSERT(gradients.getCount() > 0);

        switch (gradients[0].flavor)
        {
            case RevGradient::Flavor::Simple:
                return materializeSimpleGradients(builder, aggPrimalType, gradients);
            
            case RevGradient::Flavor::Swizzle:
                return materializeSwizzleGradients(builder, aggPrimalType, gradients);

            case RevGradient::Flavor::FieldExtract:
                return materializeFieldExtractGradients(builder, aggPrimalType, gradients);

            case RevGradient::Flavor::GetElement:
                return materializeGetElementGradients(builder, aggPrimalType, gradients);

            case RevGradient::Flavor::DifferentialPairGetElementUserCode:
                return materializeDifferentialPairUserCodeGetElementGradients(builder, aggPrimalType, gradients);

            default:
                break;
        }

        // Only reached for a flavor with no materializer. Signal the error instead of
        // flowing off the end of a function that must return a RevGradient.
        SLANG_UNEXPECTED("Unhandled gradient flavor for materialization");
    }

    // Materializes a set of `GetElement`-flavored gradients into one Simple gradient for
    // the whole aggregate.
    //
    // A temporary differential variable of the aggregate type is zero-initialized.
    // Gradients are grouped by the index operand of their forward-mode `IRGetElement`.
    // Each group is summed with `emitAggregateValue`, and the sum is stored into the
    // matching element of the variable. The loaded variable is returned as the result,
    // with no associated forward-mode inst.
    //
    // TODO: We can extend this later to grab an existing ptr to allow aggregation of
    // gradients across blocks without constructing new variables.
    // Looking up an existing pointer could also allow chained accesses like x.a.b[1] to directly
    // write into the specific sub-field that is affected without constructing intermediate vars.
    RevGradient materializeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        auto aggDiffType = (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType);
        auto accumulatorVar = builder->emitVar(aggDiffType);

        // Start from T.dzero().
        builder->emitStore(accumulatorVar, emitDZeroOfDiffInstType(builder, aggPrimalType));

        // Group the incoming gradients by the element index they affect.
        OrderedDictionary<IRInst*, List<RevGradient>> gradientsByIndex;
        for (auto gradient : gradients)
        {
            auto fwdGetElement = as<IRGetElement>(gradient.fwdGradInst);
            SLANG_ASSERT(fwdGetElement);

            auto elementIndex = fwdGetElement->getIndex();
            SLANG_ASSERT(elementIndex);

            if (!gradientsByIndex.containsKey(elementIndex))
                gradientsByIndex[elementIndex] = List<RevGradient>();

            // Within one element the contributions are re-tagged as Simple values.
            RevGradient elementGradient(
                RevGradient::Flavor::Simple,
                gradient.targetInst,
                gradient.revGradInst,
                gradient.fwdGradInst);
            gradientsByIndex[elementIndex].getValue().add(elementGradient);
        }

        // Sum each group and write it into its element of the accumulator.
        for (auto entry : gradientsByIndex)
        {
            auto elementGrads = entry.value;

            auto elementPrimalType = tryGetPrimalTypeFromDiffInst(elementGrads[0].fwdGradInst);
            SLANG_ASSERT(elementPrimalType);

            auto elementPtrType = builder->getPtrType(elementGrads[0].revGradInst->getDataType());
            auto elementAddr = builder->emitElementAddress(elementPtrType, accumulatorVar, entry.key);
            auto elementSum = emitAggregateValue(builder, elementPrimalType, elementGrads);
            builder->emitStore(elementAddr, elementSum);
        }

        auto totalGrad = builder->emitLoad(accumulatorVar);
        return RevGradient(RevGradient::Flavor::Simple, gradients[0].targetInst, totalGrad, nullptr);
    }


    // Materializes a set of `FieldExtract`-flavored gradients into one Simple gradient for
    // the whole struct-typed aggregate.
    //
    // A temporary differential variable of the aggregate type is zero-initialized.
    // Gradients are grouped by the struct key read by their forward-mode `IRFieldExtract`.
    // Each group is summed with `emitAggregateValue`, and the sum is stored into the
    // matching field of the variable. The loaded variable is returned as the result,
    // with no associated forward-mode inst.
    //
    // TODO: We can extend this later to grab an existing ptr to allow aggregation of
    // gradients across blocks without constructing new variables.
    // Looking up an existing pointer could also allow chained accesses like x.a.b[1] to directly
    // write into the specific sub-field that is affected without constructing intermediate vars.
    RevGradient materializeFieldExtractGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        auto aggDiffType = (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType);
        auto accumulatorVar = builder->emitVar(aggDiffType);

        // Start from T.dzero().
        builder->emitStore(accumulatorVar, emitDZeroOfDiffInstType(builder, aggPrimalType));

        // Group the incoming gradients by the struct field they affect.
        OrderedDictionary<IRStructKey*, List<RevGradient>> gradientsByField;
        for (auto gradient : gradients)
        {
            auto fwdFieldExtract = as<IRFieldExtract>(gradient.fwdGradInst);
            SLANG_ASSERT(fwdFieldExtract);

            auto fieldKey = as<IRStructKey>(fwdFieldExtract->getField());
            SLANG_ASSERT(fieldKey);

            if (!gradientsByField.containsKey(fieldKey))
                gradientsByField[fieldKey] = List<RevGradient>();

            // Within one field the contributions are re-tagged as Simple values.
            RevGradient fieldGradient(
                RevGradient::Flavor::Simple,
                gradient.targetInst,
                gradient.revGradInst,
                gradient.fwdGradInst);
            gradientsByField[fieldKey].getValue().add(fieldGradient);
        }

        // Sum each group and write it into its field of the accumulator.
        for (auto entry : gradientsByField)
        {
            auto fieldGrads = entry.value;

            auto fieldPrimalType = tryGetPrimalTypeFromDiffInst(fieldGrads[0].fwdGradInst);
            SLANG_ASSERT(fieldPrimalType);

            auto fieldPtrType = builder->getPtrType(fieldGrads[0].revGradInst->getDataType());
            auto fieldAddr = builder->emitFieldAddress(fieldPtrType, accumulatorVar, entry.key);
            auto fieldSum = emitAggregateValue(builder, fieldPrimalType, fieldGrads);
            builder->emitStore(fieldAddr, fieldSum);
        }

        auto totalGrad = builder->emitLoad(accumulatorVar);
        return RevGradient(RevGradient::Flavor::Simple, gradients[0].targetInst, totalGrad, nullptr);
    }

    // Sums a non-empty list of gradients, treated as plain differential values of
    // `aggPrimalType`, into a single Simple gradient.
    //
    // A single gradient is returned unchanged (including its fwdGradInst). Otherwise the
    // reverse gradients are folded left-to-right with `emitDAddOfDiffInstType`, and the
    // result carries the first gradient's target inst and no forward-mode inst.
    RevGradient materializeSimpleGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        // Both paths below read gradients[0]; an empty list has no valid result.
        SLANG_ASSERT(gradients.getCount() > 0);

        if (gradients.getCount() == 1)
        {
            // If there's only one value to add up, just return it in order
            // to avoid a stack of 0 + 0 + 0 + ...
            return gradients[0];
        }

        // If there's more than one gradient, aggregate them by adding them up.
        IRInst* currentValue = nullptr;
        for (const auto& gradient : gradients)
        {
            if (!currentValue)
            {
                currentValue = gradient.revGradInst;
                continue;
            }

            currentValue = emitDAddOfDiffInstType(builder, aggPrimalType, currentValue, gradient.revGradInst);
        }

        return RevGradient(
                    RevGradient::Flavor::Simple,
                    gradients[0].targetInst,
                    currentValue,
                    nullptr);
    }

    // Combines an arbitrary mix of gradients (of any flavor) for a value of primal type
    // `aggPrimalType` into a single reverse-mode value.
    //
    // Gradients are sorted so that equal flavors are contiguous. Each run of one flavor is
    // turned into a Simple gradient by `materializeGradientSet`, and the per-flavor results
    // are then summed by `materializeSimpleGradients`.
    //
    // With no gradients, the result is the zero differential of `aggPrimalType` when it has
    // a differential type, and nullptr otherwise (including when `aggPrimalType` is null).
    //
    // `gradients` is taken by value and sorted locally, so the caller's list is not reordered.
    IRInst* emitAggregateValue(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)
    {
        // Differential-pair types are 'hybrid' primal-differential types that would need a
        // different aggregation method; they are not expected to reach this path.
        //
        if (as<IRDifferentialPairType>(aggPrimalType))
        {
            SLANG_UNEXPECTED("Should not occur");
        }

        // Process non-simple gradients into simple gradients.
        // TODO: This is where we can improve efficiency later.
        // For instance if we have one gradient each for var.x, var.y and var.z
        // we can construct one single gradient vector out of the three vectors (i.e. float3(x_grad, y_grad, z_grad))
        // instead of creating one vector for each gradient and accumulating them 
        // (i.e. float3(x_grad, 0, 0) + float3(0, y_grad, 0) + float3(0, 0, z_grad))
        // The same concept can be extended for struct and array types (and for any combination of the three)
        // 
        List<RevGradient> simpleGradients;
        {
            // Start by sorting gradients based on flavor.
            gradients.sort([](const RevGradient& a, const RevGradient& b) -> bool { return a.flavor < b.flavor; });

            Index ii = 0;
            while (ii < gradients.getCount())
            {
                // The loop condition guarantees gradients[ii] exists.
                RevGradient::Flavor currentFlavor = gradients[ii].flavor;

                // Pull all consecutive gradients matching the flavor of the top-most gradient
                // into a temporary list.
                List<RevGradient> gradientsOfFlavor;
                for (; ii < gradients.getCount() && gradients[ii].flavor == currentFlavor; ii++)
                {
                    gradientsOfFlavor.add(gradients[ii]);
                }

                // Turn the set into a simple gradient.
                auto simpleGradient = materializeGradientSet(builder, aggPrimalType, gradientsOfFlavor);
                SLANG_ASSERT(simpleGradient.flavor == RevGradient::Flavor::Simple);

                simpleGradients.add(simpleGradient);
            }
        }

        if (simpleGradients.getCount() == 0)
        {   
            // If there are no gradients to add up, check the type and emit a 0/null value.
            auto aggDiffType = (aggPrimalType) ? diffTypeContext.getDifferentialForType(builder, aggPrimalType) : nullptr;
            if (aggDiffType != nullptr)
            {
                // If type is non-null/non-void, call T.dzero() to produce a 0 gradient.
                return emitDZeroOfDiffInstType(builder, aggPrimalType);
            }

            // Otherwise, gradients may not be applicable for this inst. return N/A
            return nullptr;
        }

        return materializeSimpleGradients(builder, aggPrimalType, simpleGradients).revGradInst;
    }

    // Returns the primal type recorded on `diffInst` by its IRDifferentialInstDecoration,
    // or nullptr when the instruction carries no such decoration.
    IRType* tryGetPrimalTypeFromDiffInst(IRInst* diffInst)
    {
        auto diffDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>();
        if (!diffDecoration)
            return nullptr;

        return diffDecoration->getPrimalType();
    }

    // Returns the witness recorded on `diffInst` by its IRDifferentialInstDecoration,
    // or nullptr when the instruction carries no such decoration.
    IRInst* tryGetWitnessFromDiffInst(IRInst* diffInst)
    {
        auto diffDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>();
        if (!diffDecoration)
            return nullptr;

        return diffDecoration->getWitness();
    }

    // Emits the zero differential ("dzero") for values whose primal type is `primalType`.
    //
    // - Arrays: the element dzero is broadcast with MakeArrayFromElement over an array of
    //   the element's differential type.
    // - User-code differential pairs: a pair whose primal and differential slots both hold
    //   the same dzero instruction of the pair's value type.
    // - Interface / associated types: the null differential packed into an existential of
    //   the differentiable interface type.
    // - Anything else: a call to the type's zero method, which is expected to exist.
    IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType)
    {
        if (auto primalArrayType = as<IRArrayType>(primalType))
        {
            auto primalElementType = primalArrayType->getElementType();
            auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, primalElementType);
            SLANG_RELEASE_ASSERT(diffElementType);

            auto diffArrayType = builder->getArrayType(diffElementType, primalArrayType->getElementCount());
            auto elementZero = emitDZeroOfDiffInstType(builder, primalElementType);
            return builder->emitMakeArrayFromElement(diffArrayType, elementZero);
        }

        if (auto pairType = as<IRDifferentialPairUserCodeType>(primalType))
        {
            // The same zero instruction is used for both halves of the pair.
            auto zeroValue = emitDZeroOfDiffInstType(builder, pairType->getValueType());
            auto pairWitness = diffTypeContext.getDiffTypeWitnessFromPairType(builder, pairType);
            auto resultPairType = builder->getDifferentialPairUserCodeType(zeroValue->getFullType(), pairWitness);
            return builder->emitMakeDifferentialPairUserCode(resultPairType, zeroValue, zeroValue);
        }

        if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
        {
            // Pack a null value into an existential type.
            return builder->emitMakeExistential(
                autodiffContext->differentiableInterfaceType,
                diffTypeContext.emitNullDifferential(builder),
                autodiffContext->nullDifferentialWitness);
        }

        auto zeroFunc = diffTypeContext.getZeroMethodForType(builder, primalType);
        SLANG_ASSERT(zeroFunc);  // Should exist.

        auto resultType = (IRType*)diffTypeContext.getDifferentialForType(builder, primalType);
        return builder->emitCallInst(resultType, zeroFunc, List<IRInst*>());
    }
    
    // Adds two existential differentials by calling the shared existential dadd helper
    // obtained from `diffTypeContext`. The call is typed with the differential of
    // `primalType`.
    IRInst* emitDAddForExistentialType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
    {
        auto daddFunc = diffTypeContext.getOrCreateExistentialDAddMethod();
        SLANG_ASSERT(daddFunc);  // Should exist.

        auto resultType = (IRType*)diffTypeContext.getDifferentialForType(builder, primalType);

        List<IRInst*> callArgs;
        callArgs.add(op1);
        callArgs.add(op2);
        return builder->emitCallInst(resultType, daddFunc, callArgs);
    }

    // Emits `op1 + op2` for two differential values whose primal type is `primalType`.
    //
    // - Arrays with a constant element count: element-wise dadd, rebuilt with MakeArray.
    //   Arrays with a non-constant count are not implemented yet.
    // - User-code differential pairs: the primal halves and the differential halves are
    //   added separately and re-packed into a pair.
    // - Interface / associated types: delegated to emitDAddForExistentialType, since one
    //   or both operands may be null-typed.
    // - Anything else: a call to the type's add method, which is expected to exist.
    IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2)
    {
        if (auto primalArrayType = as<IRArrayType>(primalType))
        {
            auto primalElementType = primalArrayType->getElementType();
            auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, primalElementType);
            SLANG_RELEASE_ASSERT(diffElementType);

            if (auto elementCountLit = as<IRIntLit>(primalArrayType->getElementCount()))
            {
                const IRIntegerValue elementCount = elementCountLit->getValue();

                List<IRInst*> elementSums;
                for (IRIntegerValue idx = 0; idx < elementCount; idx++)
                {
                    auto indexInst = builder->getIntValue(builder->getIntType(), idx);
                    auto lhsElement = builder->emitElementExtract(diffElementType, op1, indexInst);
                    auto rhsElement = builder->emitElementExtract(diffElementType, op2, indexInst);
                    elementSums.add(emitDAddOfDiffInstType(builder, primalElementType, lhsElement, rhsElement));
                }

                auto diffArrayType = builder->getArrayType(diffElementType, primalArrayType->getElementCount());
                return builder->emitMakeArray(diffArrayType, (UInt)elementSums.getCount(), elementSums.getBuffer());
            }

            // TODO: insert a runtime loop here.
            SLANG_UNIMPLEMENTED_X("dadd of dynamic array.");
        }
        else if (auto pairType = as<IRDifferentialPairUserCodeType>(primalType))
        {
            auto pairDiffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType);
            auto pairWitness = diffTypeContext.getDiffTypeWitnessFromPairType(builder, pairType);

            // Sum the primal halves (typed by the pair's value type)...
            auto lhsPrimal = builder->emitDifferentialPairGetPrimalUserCode(op1);
            auto rhsPrimal = builder->emitDifferentialPairGetPrimalUserCode(op2);
            auto primalSum = emitDAddOfDiffInstType(builder, pairType->getValueType(), lhsPrimal, rhsPrimal);

            // ...then the differential halves (typed by the pair's differential type).
            auto lhsDiff = builder->emitDifferentialPairGetDifferentialUserCode(pairDiffType, op1);
            auto rhsDiff = builder->emitDifferentialPairGetDifferentialUserCode(pairDiffType, op2);
            auto diffSum = emitDAddOfDiffInstType(builder, pairDiffType, lhsDiff, rhsDiff);

            auto resultPairType = builder->getDifferentialPairUserCodeType(pairDiffType, pairWitness);
            return builder->emitMakeDifferentialPairUserCode(resultPairType, primalSum, diffSum);
        }
        else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
        {
            // Existential operands may be null-typed; the existential helper handles that.
            return emitDAddForExistentialType(builder, primalType, op1, op2);
        }

        auto addFunc = diffTypeContext.getAddMethodForType(builder, primalType);
        SLANG_ASSERT(addFunc);  // Should exist.

        auto resultType = (IRType*)diffTypeContext.getDifferentialForType(builder, primalType);
        return builder->emitCallInst(resultType, addFunc, List<IRInst*>(op1, op2));
    }

    // Records `assignment` as one more reverse-mode gradient contribution for `fwdInst`,
    // creating an empty entry in `gradientsMap` the first time `fwdInst` is seen.
    void addRevGradientForFwdInst(IRInst* fwdInst, RevGradient assignment)
    {
        const bool alreadyTracked = gradientsMap.containsKey(fwdInst);
        if (!alreadyTracked)
            gradientsMap[fwdInst] = List<RevGradient>();

        gradientsMap.getValue(fwdInst).add(assignment);
    }

    // Returns a copy of the gradients recorded for `fwdInst` via `gradientsMap[fwdInst]`;
    // the entry is left in the map.
    List<RevGradient> getRevGradients(IRInst* fwdInst)
    {
        List<RevGradient> recorded = gradientsMap[fwdInst];
        return recorded;
    }

    // Returns the gradients recorded for `fwdInst` and removes its entry from
    // `gradientsMap`.
    List<RevGradient> popRevGradients(IRInst* fwdInst)
    {
        List<RevGradient> popped = gradientsMap.getValue(fwdInst);
        gradientsMap.remove(fwdInst);
        return popped;
    }

    // True when `gradientsMap` has an entry for `fwdInst`.
    bool hasRevGradients(IRInst* fwdInst)
    {
        const bool tracked = gradientsMap.containsKey(fwdInst);
        return tracked;
    }

    // Shared autodiff state; also supplies the differentiable interface type and the
    // null-differential witness used when zeroing existential values.
    AutoDiffSharedContext*                               autodiffContext;

    // Used to look up differential types, zero/add methods and pair witnesses.
    //
    // NOTE(review): members are initialized in declaration order, so this is constructed
    // before `pairBuilder` even though the constructor's initializer list names
    // `pairBuilder` first. Both are built only from the constructor argument, so the order
    // is currently harmless.
    DifferentiableTypeConformanceContext                 diffTypeContext;

    DifferentialPairTypeBuilder                          pairBuilder;

    // Default-initialized so this raw pointer is never read with an indeterminate value.
    // NOTE(review): it is not set in the visible part of the constructor's initializer list.
    IRBlock*                                             tempInvBlock = nullptr;

    // Pending reverse-mode gradient contributions, keyed by the forward-mode inst they
    // belong to (maintained by addRevGradientForFwdInst / popRevGradients).
    Dictionary<IRInst*, List<RevGradient>>               gradientsMap;

    Dictionary<IRInst*, IRVar*>                          revAccumulatorVarMap;

    Dictionary<IRInst*, IRVar*>                          inverseVarMap;

    List<IRInst*>                                        usedPtrs;

    Dictionary<IRBlock*, IRBlock*>                       revBlockMap;

    Dictionary<IRGlobalValueWithCode*, IRBlock*>         firstRevDiffBlockMap;

    Dictionary<IRBlock*, IRInst*>                        afterBlockMap;

    List<PendingBlockTerminatorEntry>                    pendingBlocks;

    Dictionary<IRBlock*, List<IRInst*>>                  phiGradsMap;
    
    Dictionary<IRInst*, IRInst*>                         inverseValueMap;

    List<IRUse*>                                         primalUsesToHoist;

    Dictionary<IRStore*, IRBlock*>                       mapStoreToDefBlock;

    IRCloneEnv typeInstCloneEnv = {};

};


}
back to top