// https://github.com/shader-slang/slang
// Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
// [LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
// [LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
// Tip revision: 5902acd
// slang-ir-loop-unroll.cpp
#include "slang-ir-loop-unroll.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-peephole.h"
#include "slang-ir-dominators.h"
#include "slang-ir-clone.h"
#include "slang-ir-util.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-dce.h"
#include "../core/slang-performance-profiler.h"
namespace Slang
{
/// Removes every block in `blocks` that cannot be reached from `blocks[0]`.
///
/// Reachability is computed by following successor edges, starting at the first
/// entry of the list. For each block that is not reached, any remaining uses are
/// redirected to `unreachableBlock`, the block is removed and deallocated, and its
/// slot in `blocks` is overwritten with `nullptr` (the list itself is not compacted
/// here; that is left to the caller).
///
/// Returns true if at least one block was removed, false otherwise (including when
/// `blocks` is empty).
static bool _eliminateDeadBlocks(List<IRBlock*>& blocks, IRBlock* unreachableBlock)
{
    if (blocks.getCount() == 0)
        return false;

    // Breadth-first walk from the entry block. `pending` doubles as the queue and
    // keeps growing while we scan it.
    HashSet<IRBlock*> reachable;
    List<IRBlock*> pending;
    IRBlock* entry = blocks[0];
    reachable.add(entry);
    pending.add(entry);
    for (Index cursor = 0; cursor < pending.getCount(); cursor++)
    {
        IRBlock* current = pending[cursor];
        for (auto next : current->getSuccessors())
        {
            // `add` reports whether the block was newly inserted.
            if (!reachable.add(next))
                continue;
            pending.add(next);
        }
    }

    // Sweep: drop everything the walk did not reach.
    bool removedAny = false;
    for (Index slot = 0; slot < blocks.getCount(); slot++)
    {
        IRBlock* candidate = blocks[slot];
        if (reachable.contains(candidate))
            continue;

        // Keep remaining references well-formed by pointing them at the shared
        // unreachable block before the dead block goes away.
        if (candidate->hasUses())
            candidate->replaceUsesWith(unreachableBlock);
        candidate->removeAndDeallocate();
        blocks[slot] = nullptr;
        removedAny = true;
    }
    return removedAny;
}
/// Returns how many iterations the unroller may attempt to peel from `loopInst`.
///
/// - Returns -1 when the loop carries no `IRForceUnrollDecoration`; this is the only
///   case in which a negative value is returned.
/// - When operand 0 of the decoration is an integer literal `N > 0`, returns `N + 1`
///   capped at `kMaxIterationsToAttempt` (presumably the extra attempt is what lets
///   the unroller observe that no back edge remains after the N-th iteration).
/// - When the operand is not an integer literal, or is zero or negative, no usable
///   count was given and `kMaxIterationsToAttempt` is returned.
///
/// The comparison against the cap is done on the literal's own (wide) integer type
/// before narrowing to `int`, so very large literals neither truncate nor overflow.
static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst)
{
    static constexpr int kMaxIterationsToAttempt = 4096;

    auto forceUnrollDecor = loopInst->findDecoration<IRForceUnrollDecoration>();
    if (!forceUnrollDecor)
        return -1;

    auto maxIterCount = as<IRIntLit>(forceUnrollDecor->getOperand(0));
    if (!maxIterCount)
        return kMaxIterationsToAttempt;

    const auto requested = maxIterCount->getValue();

    // Zero means "unspecified". A negative count is meaningless and must not leak out
    // as a negative result, which callers interpret as "no unroll request at all".
    if (requested <= 0)
        return kMaxIterationsToAttempt;

    // `requested + 1` would exceed the cap; clamp before narrowing to int.
    if (requested >= kMaxIterationsToAttempt)
        return kMaxIterationsToAttempt;

    return static_cast<int>(requested) + 1;
}
/// Iteratively simplifies one freshly cloned loop iteration until no more blocks die.
///
/// Each round performs, in this order:
///  1. `tryReplaceInstUsesWithSimplifiedValue` on every inst of every block in
///     `clonedBlocks`, then on every inst of `firstIterationBreakBlock`;
///  2. folding of terminators whose condition is an `IRConstant`: a conditional branch
///     becomes an unconditional branch to the taken side, and a switch becomes a branch
///     to the case label whose case value is the very same inst as the condition
///     (a switch with no such case is left untouched);
///  3. removal of blocks no longer reachable from `clonedBlocks[0]`, with their uses
///     redirected to `unreachableBlock`.
///
/// When step 3 removed something, `clonedBlocks` is compacted and another round runs;
/// otherwise the function returns. On return `clonedBlocks` holds only surviving blocks.
static void _foldAndSimplifyLoopIteration(
    TargetRequest* targetRequest,
    IRBuilder& builder,
    List<IRBlock*>& clonedBlocks,
    IRBlock* firstIterationBreakBlock,
    IRBlock* unreachableBlock)
{
    bool removedBlocks = false;
    do
    {
        // Simplify / evaluate the insts of the cloned loop body.
        for (auto block : clonedBlocks)
        {
            for (auto inst : block->getChildren())
            {
                tryReplaceInstUsesWithSimplifiedValue(targetRequest, builder.getModule(), inst);
            }
        }

        // The break block must be evaluated too: it holds the new loop inst, whose
        // arguments are the phi arguments for the next iteration.
        for (auto inst : firstIterationBreakBlock->getChildren())
        {
            tryReplaceInstUsesWithSimplifiedValue(targetRequest, builder.getModule(), inst);
        }

        // Turn terminators with a known condition into plain branches.
        for (auto block : clonedBlocks)
        {
            auto terminator = block->getTerminator();
            if (auto condBranch = as<IRConditionalBranch>(terminator))
            {
                auto knownCondition = as<IRConstant>(condBranch->getCondition());
                if (!knownCondition)
                    continue;

                IRBlock* takenBlock = nullptr;
                if (knownCondition->value.intVal != 0)
                    takenBlock = condBranch->getTrueBlock();
                else
                    takenBlock = condBranch->getFalseBlock();

                builder.setInsertBefore(condBranch);
                builder.emitBranch(takenBlock);
                condBranch->removeAndDeallocate();
            }
            else if (auto switchInst = as<IRSwitch>(terminator))
            {
                auto knownCondition = as<IRConstant>(switchInst->condition.get());
                if (!knownCondition)
                    continue;

                for (UInt caseIndex = 0; caseIndex < switchInst->getCaseCount(); caseIndex++)
                {
                    // Match by inst identity, not by value.
                    if (knownCondition != switchInst->getCaseValue(caseIndex))
                        continue;

                    builder.setInsertBefore(switchInst);
                    builder.emitBranch(switchInst->getCaseLabel(caseIndex));
                    switchInst->removeAndDeallocate();
                    break;
                }
            }
        }

        // CFG-level dead code elimination; removed entries come back as nullptr.
        removedBlocks = _eliminateDeadBlocks(clonedBlocks, unreachableBlock);
        if (removedBlocks)
        {
            // Squeeze the nullptr holes out of the list, preserving order.
            Index liveCount = 0;
            for (Index readIndex = 0; readIndex < clonedBlocks.getCount(); readIndex++)
            {
                IRBlock* survivor = clonedBlocks[readIndex];
                if (!survivor)
                    continue;
                clonedBlocks[liveCount] = survivor;
                liveCount++;
            }
            clonedBlocks.setCount(liveCount);
        }
    } while (removedBlocks);
}
// Unroll loop up to a predefined maximum number of iterations.
//
// Returns true if we can statically determine that the loop terminated within the iteration
// limit. Also returns true without unrolling anything when `_getLoopMaxIterationsToUnroll`
// reports a negative limit, and when `blocks` is empty (in which case the loop inst is simply
// replaced by a branch to its break block). Returns false if back edges into the remaining
// loop still exist after the last attempted iteration; the iterations peeled so far are left
// in the IR in that case.
//
// `blocks` is the list of loop body blocks supplied by the caller. `blocks[0]` is expected to be
// the loop's target block: its clone is used as the entry of each peeled iteration and is
// asserted to have no params left after cloning.
//
// `loopInst` is removed and deallocated while unrolling (a fresh loop inst represents the
// remaining iterations), so callers must not touch it afterwards.
//
// This operation assumes the loop does not have `continue` jumps, i.e. continueBlock == targetBlock.
static bool _unrollLoop(
    TargetRequest* targetRequest,
    IRModule* module,
    IRLoop* loopInst,
    List<IRBlock*>& blocks)
{
    // No body blocks: nothing to peel, jump straight to the break block.
    if (blocks.getCount() == 0)
    {
        IRBuilder subBuilder(module);
        subBuilder.setInsertBefore(loopInst);
        subBuilder.emitBranch(loopInst->getBreakBlock());
        loopInst->removeAndDeallocate();
        return true;
    }

    // A negative limit means there is no unroll request to honor for this loop.
    auto maxIterations = _getLoopMaxIterationsToUnroll(loopInst);
    if (maxIterations < 0)
        return true;

    // We assume all `continue`s are eliminated and turned into multi-level breaks
    // before this operation.
    SLANG_RELEASE_ASSERT(loopInst->getContinueBlock() == loopInst->getTargetBlock());

    // Insert an outer breakable region so we have a break label to use as the target for
    // any `break` jumps in the unrolled loop.
    // Transform CFG from [..., loopInst] -> [loopTarget] ->... [originalLoopBreakBlock]
    // Into: [..., loop] -> [outerBreakableRegionHeader, loopInst(phi_arg)] -> [(phi_param) loopTarget] -> ... ->
    //       [newLoopBreakBlock] -> [originalLoopBreakBlock/outerBreakableRegionBreakBlock]
    // After this transform, the original break block of the loop will serve as the break block for the
    // outer breakable region.
    IRBuilder builder(module);

    // A shared sink block that stands in for any block we delete while it still has uses.
    auto unreachableBlock = builder.createBlock();
    builder.setInsertInto(unreachableBlock);
    builder.emitUnreachable();
    unreachableBlock->insertAtEnd(loopInst->parent->parent);

    auto outerBreakableRegionHeader = builder.createBlock();
    outerBreakableRegionHeader->insertBefore(loopInst->getTargetBlock());
    auto newLoopBreakableRegionBreakBlock = builder.createBlock();
    newLoopBreakableRegionBreakBlock->insertBefore(loopInst->getBreakBlock());
    IRBlock* outerBreakableRegionBreakBlock = nullptr;
    {
        auto originalBreakBlock = loopInst->getBreakBlock();

        // Since all `break`s in the original loop body will become jumps into
        // `newLoopBreakableRegionBreakBlock` after unrolling, we need to make sure
        // `newLoopBreakableRegionBreakBlock` contains exactly the same set of
        // phi parameters as the original break block.
        IRCloneEnv cloneEnv;
        builder.setInsertInto(newLoopBreakableRegionBreakBlock);
        List<IRInst*> newParams;
        for (auto param : originalBreakBlock->getParams())
        {
            auto clonedParam = cloneInst(&cloneEnv, &builder, param);
            newParams.add(clonedParam);
        }

        // Make the existing code in the loop body to jump into `newLoopBreakableRegionBreakBlock`
        // instead, because we are going to make `originalBreakBlock` the new break block for
        // the outer breakable region.
        // Note: this must happen before the forwarding branch below is emitted, otherwise that
        // branch's own use of `originalBreakBlock` would be redirected as well.
        originalBreakBlock->replaceUsesWith(newLoopBreakableRegionBreakBlock);
        builder.emitBranch(originalBreakBlock, newParams.getCount(), newParams.getBuffer());

        // Use the original break block as the break block for the new outer loop.
        outerBreakableRegionBreakBlock = originalBreakBlock;

        // Use a loop inst to enter the breakable region. (This isn't a real loop).
        builder.setInsertBefore(loopInst);
        builder.emitLoop(
            outerBreakableRegionHeader,
            outerBreakableRegionBreakBlock,
            outerBreakableRegionHeader);

        // The original loop inst should now be moved into `outerBreakableRegionHeader`.
        loopInst->insertAtEnd(outerBreakableRegionHeader);
    }

    bool loopTerminated = false;
    for (int attemptedIterations = 0; attemptedIterations < maxIterations; attemptedIterations++)
    {
        // Our task is to peel off the first iteration and put it in front of the
        // loop.
        // We will create a breakable region (via single iteration loop), and clone the loop body
        // into this region. This region is defined by the header block `firstIterationLoopHeader`,
        // and the converge block `firstIterationBreakBlock`.
        IRCloneEnv cloneEnv;
        auto loopTargetBlock = loopInst->getTargetBlock();
        auto firstIterationLoopHeader = builder.createBlock();
        firstIterationLoopHeader->insertBefore(loopTargetBlock);
        auto firstIterationBreakBlock = builder.createBlock();
        firstIterationBreakBlock->insertBefore(loopTargetBlock);

        // Map loop params for first iteration to arguments, so that
        // when we clone the blocks, these parameters will get replaced
        // with the actual arguments.
        UInt argId = 0;
        for (auto param : loopTargetBlock->getParams())
        {
            cloneEnv.mapOldValToNew[param] = loopInst->getArg(argId);
            argId++;
        }

        // While cloning the loop body, if we see any `break`s, we replace it with a branch
        // into outerBreakableRegionBreakBlock.
        // We replace the back edge with a jump into firstIterationBreakBlock.
        // The original loop will start from firstIterationBreakBlock.
        cloneEnv.mapOldValToNew[loopInst->getBreakBlock()] = outerBreakableRegionBreakBlock;
        cloneEnv.mapOldValToNew[loopInst->getTargetBlock()] = firstIterationBreakBlock;

        // Wire up the breakable region blocks.
        // Note that the breakable region header will never have any phi params because there will never
        // be back jumps into the header (it is a single iteration loop just for the break label).
        builder.setInsertBefore(loopInst);
        builder.emitLoop(firstIterationLoopHeader, firstIterationBreakBlock, firstIterationLoopHeader);

        // The `firstIterationBreakBlock` is supposed to act as the `targetBlock` for the back-jump in the
        // loop body. Therefore, if the original loop target block has any phi params, we will need the
        // same set of phi params in `firstIterationBreakBlock` so keep those branches valid.
        builder.setInsertInto(firstIterationBreakBlock);
        {
            // A separate clone environment is used so that these param clones do not override the
            // param -> argument mapping registered in `cloneEnv` above.
            IRCloneEnv paramCloneEnv;
            List<IRInst*> newParams;
            for (auto param : loopTargetBlock->getParams())
            {
                newParams.add(cloneInst(&paramCloneEnv, &builder, param));
            }

            // In `firstIterationBreakBlock`, we emit a new loop inst
            // to start a loop for the remaining iterations.
            auto newLoopInst = as<IRLoop>(builder.emitLoop(
                loopTargetBlock,
                loopInst->getBreakBlock(),
                loopInst->getContinueBlock(),
                newParams.getCount(),
                newParams.getBuffer()));
            loopInst->removeAndDeallocate();

            // Update `loopInst` to represent the remaining loop iterations that are yet to be unrolled.
            loopInst = newLoopInst;
        }

        // With the break region set up and wired, we can now clone the loop body into the break region.
        // We create all the blocks first, and setup the clone mapping for the blocks so when we
        // clone the insts later, the branch targets will automatically set to their clones.
        List<IRBlock*> clonedBlocks;
        for (auto b : blocks)
        {
            builder.setInsertBefore(firstIterationBreakBlock);
            auto clonedBlock = builder.createBlock();
            clonedBlock->insertBefore(firstIterationBreakBlock);
            // `addIfNotExists` keeps the target-block -> firstIterationBreakBlock mapping set up
            // above intact when `b` is the loop target block.
            cloneEnv.mapOldValToNew.addIfNotExists(b, clonedBlock);
            clonedBlocks.add(clonedBlock);
        }

        // Now clone the insts inside each block.
        for (Index i = 0; i < blocks.getCount(); i++)
        {
            auto originalBlock = blocks[i];
            auto clonedBlock = clonedBlocks[i];
            builder.setInsertInto(clonedBlock);
            for (auto inst : originalBlock->getChildren())
            {
                cloneInst(&cloneEnv, &builder, inst);
            }
        }

        // Wire the break region header to jump to the first loop body block.
        builder.setInsertInto(firstIterationLoopHeader);
        builder.emitBranch(clonedBlocks[0]);

        // Cloned first block of the iteration should not have any params,
        // they must have been replaced with actual arguments since we have set up
        // the mappings for them before the clone.
        SLANG_RELEASE_ASSERT(clonedBlocks[0]->getFirstParam() == nullptr);

        // With all the insts for the first iteration in place, we now iteratively run
        // SCCP and simplification for the cloned blocks, in hope that some
        // conditional jumps can be folded into unconditional jumps.
        _foldAndSimplifyLoopIteration(
            targetRequest, builder, clonedBlocks, firstIterationBreakBlock, unreachableBlock);

        // Now we have peeled off one iteration from the loop, we check if there are any
        // branches into next iteration, if not, the loop terminates and we are done.
        bool hasJumpsToRemainingLoop = false;
        for (auto b : clonedBlocks)
        {
            for (auto succ : b->getSuccessors())
            {
                if (succ == firstIterationBreakBlock)
                {
                    hasJumpsToRemainingLoop = true;
                    break;
                }
            }
            // One back edge is enough; no need to scan the remaining blocks.
            if (hasJumpsToRemainingLoop)
                break;
        }

        if (!hasJumpsToRemainingLoop)
        {
            loopTerminated = true;

            // Now we know the loop terminates and we have just emitted the last iteration.
            // We need to replace all uses of the insts defined within the loop body with their
            // clones in the last iteration.
            HashSet<IRBlock*> blockSet;
            for (auto block : blocks)
            {
                blockSet.add(block);
            }
            for (auto block : blocks)
            {
                for (auto inst : block->getChildren())
                {
                    IRInst* newInst = nullptr;
                    if (!cloneEnv.mapOldValToNew.tryGetValue(inst, newInst))
                        continue;
                    for (auto use = inst->firstUse; use;)
                    {
                        // Fetch the successor first: `set` unlinks `use` from this use list.
                        auto nextUse = use->nextUse;
                        // Only uses outside the original loop body are redirected; uses inside
                        // it go away together with the blocks below.
                        if (!blockSet.contains(as<IRBlock>(use->getUser()->getParent())))
                        {
                            use->set(newInst);
                        }
                        use = nextUse;
                    }
                }
            }

            // Now we can safely delete the original loop blocks.
            for (auto block : blocks)
            {
                block->replaceUsesWith(unreachableBlock);
                block->removeAndDeallocate();
            }

            // firstIterationBreakBlock is no longer reachable, so we can delete its children
            // and turn it into an unreachable block.
            firstIterationBreakBlock->removeAndDeallocateAllDecorationsAndChildren();
            builder.setInsertInto(firstIterationBreakBlock);
            builder.emitUnreachable();
            break;
        }
    }
    return loopTerminated;
}
// Gathers the loop insts of `func` that are accepted by `filter`.
//
// Blocks are visited in the order produced by `getPostorder(func)`, and a loop inst is
// picked up when it is the terminator of a visited block. The post order is what lets
// inner loops appear in the result before the loops that enclose them.
//
// `filter` is invoked with each candidate `IRLoop*` and must return something convertible
// to bool; only loops for which it yields true are returned.
template<typename TFunc>
List<IRLoop*> collectLoopsInFunc(IRGlobalValueWithCode* func, const TFunc& filter)
{
    List<IRLoop*> collected;
    auto visitOrder = getPostorder(func);
    for (auto block : visitOrder)
    {
        auto loopInst = as<IRLoop>(block->getTerminator());
        if (!loopInst)
            continue;
        if (!filter(loopInst))
            continue;
        collected.add(loopInst);
    }
    return collected;
}
// Unrolls every loop in `func` that carries an `IRForceUnrollDecoration`.
//
// Loops are processed in the order returned by `collectLoopsInFunc` (inner loops first).
// For each one, `continue` blocks are eliminated, the loop's region blocks are collected,
// and `_unrollLoop` is run; afterwards the function is cleaned up with `simplifyCFG` and
// `eliminateDeadCode` before the next loop is attempted.
//
// Returns false as soon as one loop cannot be unrolled; in that case
// `Diagnostics::cannotUnrollLoop` is reported at the loop's source location when `sink`
// is non-null. Returns true otherwise, including when there is nothing to unroll.
bool unrollLoopsInFunc(
    TargetRequest* targetRequest,
    IRModule* module,
    IRGlobalValueWithCode* func,
    DiagnosticSink* sink)
{
    auto isForceUnrolled = [](IRLoop* l)
    { return l->findDecoration<IRForceUnrollDecoration>() != nullptr; };
    List<IRLoop*> candidates = collectLoopsInFunc(func, isForceUnrolled);

    for (Index index = 0; index < candidates.getCount(); index++)
    {
        IRLoop* loopInst = candidates[index];

        // Remove any continue jumps from the loop.
        eliminateContinueBlocks(module, loopInst);

        auto regionBlocks = collectBlocksInRegion(func, loopInst);

        // Remember the location up front: the loop inst is deallocated during unrolling.
        auto reportLoc = loopInst->sourceLoc;

        const bool unrolled = _unrollLoop(targetRequest, module, loopInst, regionBlocks);
        if (!unrolled)
        {
            if (sink)
                sink->diagnose(reportLoc, Diagnostics::cannotUnrollLoop);
            return false;
        }

        // Make sure we simplify things as much as possible before
        // attempting to potentially unroll outer loop.
        simplifyCFG(func, CFGSimplificationOptions::getDefault());
        eliminateDeadCode(func);
    }
    return true;
}
// Runs `unrollLoopsInFunc` on every global inst of `module` that is an
// `IRGlobalValueWithCode`, skipping `IRGeneric` insts.
//
// Stops and returns false at the first function for which unrolling fails; returns true
// when every visited function was processed successfully.
bool unrollLoopsInModule(TargetRequest* targetRequest, IRModule* module, DiagnosticSink* sink)
{
    SLANG_PROFILE;

    for (auto globalInst : module->getGlobalInsts())
    {
        // Generics are not processed by this pass.
        if (as<IRGeneric>(globalInst))
            continue;

        auto code = as<IRGlobalValueWithCode>(globalInst);
        if (!code)
            continue;

        if (!unrollLoopsInFunc(targetRequest, module, code, sink))
            return false;
    }
    return true;
}
// Rewrites `loopInst` so that its continue block coincides with its target block.
//
// The continue jumps are eliminated by turning a loop of the form:
//
//      for (;;)
//      {
//          <loop body>
//      continueBlock:
//          <continuePart>
//      }
//
// into:
//
//      for (;;) // original loop
//      {
//          for (;;) // per-iteration breakable region
//          {
//              <loop body>
//          }
//      regionBreakBlock:
//          <continuePart>
//      }
//
// so that every former `continue` becomes a "break" out of the inner region.
//
// Nothing is done when the continue block already is the target block. When the continue
// block has at most one use, it is removed and the loop's continue operand is pointed at
// the target block instead.
void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst)
{
    auto originalContinue = loopInst->getContinueBlock();
    auto loopEntry = loopInst->getTargetBlock();

    // Already in the desired shape.
    if (originalContinue == loopEntry)
        return;

    // A continue block without more than one use is not jumped to from the body,
    // so it can simply be dropped.
    if (originalContinue && !originalContinue->hasMoreThanOneUse())
    {
        loopInst->continueBlock.set(loopEntry);
        originalContinue->removeAndDeallocate();
        return;
    }

    // There is a genuine continue block: wrap each iteration in a breakable region.
    IRBuilder builder(module);

    auto regionHeader = builder.createBlock();
    regionHeader->insertBefore(loopEntry);
    auto regionBreakBlock = builder.createBlock();
    regionBreakBlock->insertBefore(originalContinue);

    // The outer loop now enters (and continues to) the region header. Remaining uses of
    // the old entry block are redirected before the inner loop inst adds new ones.
    loopInst->block.set(regionHeader);
    loopInst->continueBlock.set(regionHeader);
    loopEntry->replaceUsesWith(regionHeader);

    // Move the params of the old entry block over to the region header.
    moveParams(regionHeader, loopEntry);

    // The region header starts the inner single-iteration "loop" around the body.
    builder.setInsertInto(regionHeader);
    builder.emitLoop(loopEntry, regionBreakBlock, loopEntry);

    // Former jumps to the continue block now land on the region's break block, which
    // takes over the continue block's params and then forwards to it.
    originalContinue->replaceUsesWith(regionBreakBlock);
    builder.setInsertInto(regionBreakBlock);
    moveParams(regionBreakBlock, originalContinue);
    builder.emitBranch(originalContinue);

    // If the original loop can be executed up to N times, the new loop may be executed
    // up to N+1 times (although most insts are skipped in the last traversal).
    if (auto itersDecoration = loopInst->findDecoration<IRLoopMaxItersDecoration>())
    {
        auto previousLimit = itersDecoration->getMaxIters();
        itersDecoration->removeAndDeallocate();
        builder.addLoopMaxItersDecoration(loopInst, previousLimit + 1);
    }
}
// Applies `eliminateContinueBlocks` to every loop in `func` whose continue block differs
// from its target block. Loops are visited in the order returned by `collectLoopsInFunc`
// (inner loops first); a function without such loops is left untouched.
void eliminateContinueBlocksInFunc(IRModule* module, IRGlobalValueWithCode* func)
{
    auto hasSeparateContinueBlock = [](IRLoop* l)
    { return l->getContinueBlock() != l->getTargetBlock(); };

    List<IRLoop*> pendingLoops = collectLoopsInFunc(func, hasSeparateContinueBlock);
    for (Index index = 0; index < pendingLoops.getCount(); index++)
    {
        eliminateContinueBlocks(module, pendingLoops[index]);
    }
}
}