https://github.com/shader-slang/slang
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-autodiff-propagate.h
// slang-ir-autodiff-propagate.h
#pragma once
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-compiler.h"
#include "slang-ir-autodiff.h"
namespace Slang
{
inline bool isDifferentialInst(IRInst* inst)
{
return inst->findDecoration<IRDifferentialInstDecoration>();
}
inline bool isPrimalInst(IRInst* inst)
{
return inst->findDecoration<IRPrimalInstDecoration>() || (as<IRConstant>(inst) != nullptr);
}
inline bool isMixedDifferentialInst(IRInst* inst)
{
return inst->findDecoration<IRMixedDifferentialInstDecoration>();
}
struct DiffPropagationPass : InstPassBase
{
AutoDiffSharedContext* autodiffContext;
DiffPropagationPass(AutoDiffSharedContext* autodiffContext) :
autodiffContext(autodiffContext),
InstPassBase(autodiffContext->moduleInst->getModule())
{
}
bool shouldInstBeMarkedDifferential(IRInst* inst)
{
for (UIndex ii = 0; ii < inst->getOperandCount(); ii ++)
{
if (isDifferentialInst(inst->getOperand(ii)))
{
return true;
}
}
return false;
}
void addPendingUsersToWorkList(IRInst* inst)
{
auto use = inst->firstUse;
while (use)
{
if (!isDifferentialInst(use->getUser()))
{
addToWorkList(use->getUser());
}
use = use->nextUse;
}
}
// Propagate IRDifferentialInstDecoration for all children of instWithChildren.
void propagateDiffInstDecoration(IRBuilder* builder, IRInst* instWithChildren)
{
List<IRInst*> initialList;
// Mark 'GetDifferential' insts as differential.
processChildInstsOfType<IRDifferentialPairGetDifferential>(
kIROp_DifferentialPairGetDifferential,
instWithChildren,
[&](IRDifferentialPairGetDifferential* getDifferentialInst)
{
builder->markInstAsDifferential(getDifferentialInst);
initialList.add(getDifferentialInst);
});
workList.clear();
workListSet.clear();
// Add the marked insts to the work list.
for (auto inst : initialList)
{
// Look for insts marked as differential.
if (isDifferentialInst(inst))
addPendingUsersToWorkList(inst);
}
// Propagate to all users..
while (workList.getCount() != 0)
{
IRInst* inst = pop();
// Skip if this is already a differential inst.
if (isDifferentialInst(inst))
{
continue;
}
if (shouldInstBeMarkedDifferential(inst))
{
builder->markInstAsDifferential(inst);
addPendingUsersToWorkList(inst);
}
}
}
};
}