https://github.com/shader-slang/slang
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-addr-inst-elimination.cpp
#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
namespace Slang
{
// Rewrites address load/store into value extract/updates to allow SSA transform to apply to struct and array elements.
// For example,
// load(elementPtr(arr, 1)) ==> elementExtract(load(arr), 1)
// store(fieldAddr(s, field_key), val) ==> store(s, updateField(load(s), fieldKey, val))
// After this transform, all address operands of `load` and `store` insts will be either a var or a param.
struct AddressInstEliminationContext
{
IRModule* module;
DiagnosticSink* sink;
IRInst* getValue(IRBuilder& builder, IRInst* addr)
{
switch (addr->getOp())
{
default:
return builder.emitLoad(addr);
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
{
IRInst* args[] = {getValue(builder, addr->getOperand(0)), addr->getOperand(1)};
return builder.emitIntrinsicInst(
cast<IRPtrTypeBase>(addr->getFullType())->getValueType(),
(addr->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract),
2,
args);
}
}
}
void storeValue(IRBuilder& builder, IRInst* addr, IRInst* val)
{
List<IRInst*> accessChain;
for (auto inst = addr; inst;)
{
switch (inst->getOp())
{
default:
accessChain.add(inst);
goto endLoop;
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
accessChain.add(inst->getOperand(1));
inst = inst->getOperand(0);
break;
}
}
endLoop:;
auto lastAddr = accessChain.getLast();
accessChain.removeLast();
accessChain.reverse();
if (accessChain.getCount())
{
auto lastVal = builder.emitLoad(lastAddr);
auto update = builder.emitUpdateElement(lastVal, accessChain, val);
builder.emitStore(lastAddr, update);
}
else
{
builder.emitStore(lastAddr, val);
}
}
void transformLoadAddr(IRUse* use)
{
auto addr = use->get();
auto load = as<IRLoad>(use->getUser());
IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
auto value = getValue(builder, addr);
load->replaceUsesWith(value);
load->removeAndDeallocate();
}
void transformStoreAddr(IRUse* use)
{
auto addr = use->get();
auto store = as<IRStore>(use->getUser());
IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
storeValue(builder, addr, store->getVal());
store->removeAndDeallocate();
}
void transformCallAddr(IRUse* use)
{
auto addr = use->get();
auto call = as<IRCall>(use->getUser());
// Don't change the use if addr is a non mutable address.
if (as<IRConstRefType>(getRootAddr(addr)->getDataType()))
{
return;
}
IRBuilder builder(module);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
// Store the initial value of the mutable argument into temp var.
// If this is an `out` var, the initial value will be undefined,
// which will get cleaned up later into a `defaultConstruct`.
builder.emitStore(tempVar, getValue(builder, addr));
builder.setInsertAfter(call);
storeValue(builder, addr, builder.emitLoad(tempVar));
use->set(tempVar);
}
SlangResult eliminateAddressInstsImpl(
IRFunc* func,
DiagnosticSink* inSink)
{
sink = inSink;
IRBuilder builder(func->getModule());
List<IRInst*> workList;
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
{
if (as<IRConstRefType>(getRootAddr(inst)->getDataType()))
continue;
if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType()))
{
auto valType = unwrapAttributedType(ptrType->getValueType());
if (!getResolvedInstForDecorations(valType)->findDecoration<IRNonCopyableTypeDecoration>())
{
workList.add(inst);
}
}
}
}
for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++)
{
auto addrInst = workList[workListIndex];
for (auto use = addrInst->firstUse; use; )
{
auto nextUse = use->nextUse;
if (as<IRDecoration>(use->getUser()))
{
use = nextUse;
continue;
}
switch (use->getUser()->getOp())
{
case kIROp_Load:
transformLoadAddr(use);
break;
case kIROp_Store:
transformStoreAddr(use);
break;
case kIROp_Call:
transformCallAddr(use);
break;
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
break;
default:
sink->diagnose(use->getUser()->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff);
break;
}
use = nextUse;
}
}
return SLANG_OK;
}
};
SlangResult eliminateAddressInsts(
IRFunc* func,
DiagnosticSink* sink)
{
AddressInstEliminationContext ctx;
ctx.module = func->getModule();
ctx.sink = sink;
return ctx.eliminateAddressInstsImpl(func, sink);
}
} // namespace Slang