https://github.com/shader-slang/slang
Raw File
Tip revision: d7ba60c993366b4aaf6ef8ee7d8eab940d61eac8 authored by Yong He on 03 April 2023, 03:43:09 UTC
Fix type legalization pass. (#2768)
Tip revision: d7ba60c
slang-ir-autodiff-propagate.h
// slang-ir-autodiff-propagate.h
#pragma once

#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-compiler.h"

#include "slang-ir-autodiff.h"

namespace Slang
{

// Returns true when `inst` has been tagged with an IRDifferentialInstDecoration,
// i.e. it was marked as computing a differential value.
inline bool isDifferentialInst(IRInst* inst)
{
    auto* diffDecoration = inst->findDecoration<IRDifferentialInstDecoration>();
    return diffDecoration != nullptr;
}

// Returns true when `inst` belongs to the primal computation: either it carries
// an IRPrimalInstDecoration, or it is an IRConstant (constants are treated as
// primal without needing a decoration).
inline bool isPrimalInst(IRInst* inst)
{
    if (inst->findDecoration<IRPrimalInstDecoration>() != nullptr)
        return true;

    return as<IRConstant>(inst) != nullptr;
}

// Returns true when `inst` carries an IRMixedDifferentialInstDecoration.
inline bool isMixedDifferentialInst(IRInst* inst)
{
    auto* mixedDecoration = inst->findDecoration<IRMixedDifferentialInstDecoration>();
    return mixedDecoration != nullptr;
}

struct DiffPropagationPass : InstPassBase
{
    AutoDiffSharedContext*                  autodiffContext;

    DiffPropagationPass(AutoDiffSharedContext* autodiffContext) : 
        autodiffContext(autodiffContext),
        InstPassBase(autodiffContext->moduleInst->getModule())
    {

    }


    bool shouldInstBeMarkedDifferential(IRInst* inst)
    {
        for (UIndex ii = 0; ii < inst->getOperandCount(); ii ++)
        {
            if (isDifferentialInst(inst->getOperand(ii)))
            {
                return true;   
            }
        }

        return false;
    }

    void addPendingUsersToWorkList(IRInst* inst)
    {
        auto use = inst->firstUse;
        while (use)
        {
            if (!isDifferentialInst(use->getUser()))
            {
                addToWorkList(use->getUser());
            }
            use = use->nextUse;
        }
    }

    // Propagate IRDifferentialInstDecoration for all children of instWithChildren.
    void propagateDiffInstDecoration(IRBuilder* builder, IRInst* instWithChildren)
    {
        List<IRInst*> initialList;
        // Mark 'GetDifferential' insts as differential.
        processChildInstsOfType<IRDifferentialPairGetDifferential>(
            kIROp_DifferentialPairGetDifferential, 
            instWithChildren, 
            [&](IRDifferentialPairGetDifferential* getDifferentialInst)
            {
                builder->markInstAsDifferential(getDifferentialInst);
                initialList.add(getDifferentialInst);
            });


        workList.clear();
        workListSet.Clear();

        // Add the marked insts to the work list.
        for (auto inst : initialList)
        {
            // Look for insts marked as differential.
            if (isDifferentialInst(inst))
                addPendingUsersToWorkList(inst);
        }

        // Propagate to all users..
        while (workList.getCount() != 0)
        {
            IRInst* inst = pop();

            // Skip if this is already a differential inst.
            if (isDifferentialInst(inst))
            {
                continue;
            }

            if (shouldInstBeMarkedDifferential(inst))
            {
                builder->markInstAsDifferential(inst);
                addPendingUsersToWorkList(inst);
            }
        }
    }
};

}
back to top