// Source: https://github.com/shader-slang/slang
// Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 (5902acd), authored by Yong He on 23 January 2024, 07:19:40 UTC
// Commit: [LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
// File: slang-ir-peephole.cpp
#include "slang-ir-peephole.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-sccp.h"
#include "slang-ir-dominators.h"
#include "slang-ir-util.h"
#include "slang-ir-layout.h"
namespace Slang
{
/// Local IR simplifications ("peephole" rewrites).
///
/// `processInst` examines a single instruction and, when it matches one of the
/// patterns below, redirects all of its uses to a simpler equivalent value and
/// (unless `removeOldInst` is false) removes it. `processFunc` applies
/// `processInst` to every descendant of a root instruction, repeating whole
/// sweeps until one of them makes no change.
struct PeepholeContext : InstPassBase
{
    PeepholeContext(IRModule* inModule)
        : InstPassBase(inModule)
    {}

    /// Set by `processInst` whenever it rewrites an instruction.
    bool changed = false;

    /// Floating-point mode of the code currently being visited. Folds that are
    /// not exact under IEEE semantics (e.g. `0 * x -> 0`, `x - x -> 0`) are only
    /// applied in `Fast` mode or to integral values.
    FloatingPointMode floatingPointMode = FloatingPointMode::Precise;

    /// When false, a rewritten instruction is left in place and only its uses are
    /// redirected. `tryReplaceInstUsesWithSimplifiedValue` sets this to false.
    bool removeOldInst = true;

    /// True while visiting the contents of an `IRGeneric`; natural-stride queries
    /// are not folded there.
    bool isInGeneric = false;

    /// Target used for layout queries. Null means `GetNaturalStride` is not folded.
    TargetRequest* targetRequest = nullptr;

    /// Remove and deallocate `inst` unless this context keeps replaced instructions.
    void maybeRemoveOldInst(IRInst* inst)
    {
        if (removeOldInst)
            inst->removeAndDeallocate();
    }

    /// Fold an element/field extract whose (possibly nested) base is an `updateElement`.
    ///
    /// The access keys of the `FieldExtract`/`GetElement` chain rooted at `inst` are
    /// collected in access order. If the chain bottoms out at
    /// `updateElement(x, updateKeys, val)`:
    ///   - if `updateKeys` equals a prefix of the extract keys, the result is `val`
    ///     (further extracted with the remaining keys, if any);
    ///   - if some key position is provably different (a struct key that differs, or
    ///     two integer literals with different values), the update cannot affect the
    ///     extracted element, so the same extract is applied to `x` instead.
    /// Returns true if `inst` was replaced.
    bool tryFoldElementExtractFromUpdateInst(IRInst* inst)
    {
        bool isAccessChainEqual = false;
        bool isAccessChainNotEqual = false;
        List<IRInst*> chainKey;
        IRInst* chainNode = inst;
        // Walk down through nested extracts, recording each access key.
        for (;;)
        {
            switch (chainNode->getOp())
            {
            case kIROp_FieldExtract:
            case kIROp_GetElement:
                chainKey.add(chainNode->getOperand(1));
                chainNode = chainNode->getOperand(0);
                continue;
            }
            break;
        }
        // Keys were collected from the outermost extract inward; reverse them so
        // they are in access order (base first).
        chainKey.reverse();
        if (auto updateInst = as<IRUpdateElement>(chainNode))
        {
            // If we see an extract(updateElement(x, accessChain, val), accessChain), then
            // we can replace the inst with val.
            if (updateInst->getAccessKeyCount() > (UInt)chainKey.getCount())
                return false;
            isAccessChainEqual = true;
            for (UInt i = 0; i < updateInst->getAccessKeyCount(); i++)
            {
                if (updateInst->getAccessKey(i) != chainKey[i])
                {
                    // The keys are different instructions; check whether they provably
                    // address different elements.
                    isAccessChainEqual = false;
                    if (as<IRStructKey>(chainKey[i]))
                    {
                        isAccessChainNotEqual = true;
                        break;
                    }
                    else
                    {
                        if (auto constIndex1 = as<IRIntLit>(updateInst->getAccessKey(i)))
                        {
                            if (auto constIndex2 = as<IRIntLit>(chainKey[i]))
                            {
                                if (constIndex1->getValue() != constIndex2->getValue())
                                {
                                    isAccessChainNotEqual = true;
                                    break;
                                }
                            }
                        }
                    }
                }
            }
            if (isAccessChainEqual)
            {
                auto remainingKeys = chainKey.getArrayView(
                    updateInst->getAccessKeyCount(),
                    chainKey.getCount() - updateInst->getAccessKeyCount());
                if (remainingKeys.getCount() == 0)
                {
                    inst->replaceUsesWith(updateInst->getElementValue());
                    maybeRemoveOldInst(inst);
                    return true;
                }
                else if (remainingKeys.getCount() > 0)
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto newValue = builder.emitElementExtract(updateInst->getElementValue(), remainingKeys);
                    inst->replaceUsesWith(newValue);
                    maybeRemoveOldInst(inst);
                    return true;
                }
            }
            else if (isAccessChainNotEqual)
            {
                // If we see an extract(updateElement(x, accessChain, val), accessChain2), where accessChain!=accessChain2,
                // then we can replace the inst with extract(x, accessChain2).
                IRBuilder builder(module);
                builder.setInsertBefore(inst);
                auto newInst = builder.emitElementExtract(updateInst->getOldValue(), chainKey.getArrayView());
                inst->replaceUsesWith(newInst);
                maybeRemoveOldInst(inst);
                return true;
            }
        }
        return false;
    }

    /// Algebraic identities for Add/Sub/Mul/Div/And/Or, using `isZero`/`isOne` to
    /// recognize constant operands.
    ///
    /// Folds that are wrong for IEEE floats with NaN/infinity operands
    /// (`0 * x`, `0 / x`, `x - x`) are only applied when `allowUnsafeOptimizations`
    /// holds, i.e. in Fast FP mode or for integral result types.
    /// Returns true if `inst` was replaced.
    bool tryOptimizeArithmeticInst(IRInst* inst)
    {
        bool allowUnsafeOptimizations =
            (floatingPointMode == FloatingPointMode::Fast ||
                isIntegralScalarOrCompositeType(inst->getDataType()));

        // Replace `inst` with `replacement`. If the types differ, the only supported
        // conversion is broadcasting a scalar into `inst`'s vector type.
        auto tryReplace = [&](IRInst* replacement) -> bool
        {
            if (replacement->getFullType() != inst->getFullType())
            {
                // If the operand type is different from result type,
                // we try to convert for some known cases.
                if (auto vectorType = as<IRVectorType>(inst->getFullType()))
                {
                    if (vectorType->getElementType() != replacement->getFullType())
                        return false;
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    replacement = builder.emitMakeVectorFromScalar(inst->getFullType(), replacement);
                }
                else
                {
                    return false;
                }
            }
            inst->replaceUsesWith(replacement);
            maybeRemoveOldInst(inst);
            return true;
        };

        switch (inst->getOp())
        {
        case kIROp_Add:
            if (isZero(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(1));
            }
            else if (isZero(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(0));
            }
            break;
        case kIROp_Sub:
            if (isZero(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (allowUnsafeOptimizations && inst->getOperand(0) == inst->getOperand(1))
            {
                // x - x == 0 only for finite x: inf - inf and NaN - NaN are NaN, so
                // this is gated like the other zero folds below.
                IRBuilder builder(inst);
                builder.setInsertBefore(inst);
                return tryReplace(builder.emitDefaultConstruct(inst->getDataType()));
            }
            break;
        case kIROp_Mul:
            if (isOne(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(1));
            }
            else if (isOne(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (allowUnsafeOptimizations && isZero(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (allowUnsafeOptimizations && isZero(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(1));
            }
            break;
        case kIROp_Div:
            if (allowUnsafeOptimizations && isZero(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (isOne(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(0));
            }
            break;
        case kIROp_And:
            // false && x -> false; x && true -> x.
            if (isZero(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (isZero(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(1));
            }
            else if (isOne(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (isOne(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(1));
            }
            break;
        case kIROp_Or:
            // false || x -> x; true || x -> true.
            if (isZero(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(1));
            }
            else if (isZero(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(0));
            }
            else if (isOne(inst->getOperand(1)))
            {
                return tryReplace(inst->getOperand(1));
            }
            else if (isOne(inst->getOperand(0)))
            {
                return tryReplace(inst->getOperand(0));
            }
            break;
        }
        return false;
    }

    /// Try every applicable rewrite on the single instruction `inst`.
    /// Sets `changed` when `inst` was rewritten.
    void processInst(IRInst* inst)
    {
        // Track the floating-point mode of the code being visited. A top-level
        // function/global without an override resets to the default, so an
        // override on one function does not leak into the functions visited after
        // it. Nested code-bearing insts (e.g. the function inside a generic)
        // inherit the enclosing mode unless they carry their own override.
        if (as<IRGlobalValueWithCode>(inst))
        {
            if (auto fpModeDecor = inst->findDecoration<IRFloatingModeOverrideDecoration>())
                floatingPointMode = fpModeDecor->getFloatingPointMode();
            else if (inst->getParent() && inst->getParent()->getOp() == kIROp_Module)
                floatingPointMode = FloatingPointMode::Precise;
        }

        switch (inst->getOp())
        {
        case kIROp_GetResultError:
            if (inst->getOperand(0)->getOp() == kIROp_MakeResultError)
            {
                inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
                maybeRemoveOldInst(inst);
                changed = true;
            }
            break;
        case kIROp_GetResultValue:
            if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue)
            {
                inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
                maybeRemoveOldInst(inst);
                changed = true;
            }
            break;
        case kIROp_IsResultError:
            if (inst->getOperand(0)->getOp() == kIROp_MakeResultError)
            {
                IRBuilder builder(module);
                inst->replaceUsesWith(builder.getBoolValue(true));
                maybeRemoveOldInst(inst);
                changed = true;
            }
            else if (inst->getOperand(0)->getOp() == kIROp_MakeResultValue)
            {
                IRBuilder builder(module);
                inst->replaceUsesWith(builder.getBoolValue(false));
                maybeRemoveOldInst(inst);
                changed = true;
            }
            break;
        case kIROp_GetTupleElement:
            if (inst->getOperand(0)->getOp() == kIROp_MakeTuple)
            {
                auto tuple = inst->getOperand(0);
                if (auto intLit = as<IRIntLit>(inst->getOperand(1)))
                {
                    // Bounds-check the literal index, as the MakeArray case does.
                    auto elementIndex = intLit->getValue();
                    if (elementIndex >= 0 && elementIndex < (IRIntegerValue)tuple->getOperandCount())
                    {
                        inst->replaceUsesWith(tuple->getOperand((UInt)elementIndex));
                        maybeRemoveOldInst(inst);
                        changed = true;
                    }
                }
            }
            break;
        case kIROp_FieldExtract:
            if (inst->getOperand(0)->getOp() == kIROp_MakeStruct)
            {
                // fieldExtract(makeStruct(a0, a1, ...), key_i) -> a_i
                auto field = as<IRFieldExtract>(inst)->field.get();
                Index fieldIndex = -1;
                auto structType = as<IRStructType>(inst->getOperand(0)->getDataType());
                if (structType)
                {
                    Index i = 0;
                    for (auto sfield : structType->getFields())
                    {
                        if (sfield->getKey() == field)
                        {
                            fieldIndex = i;
                            break;
                        }
                        i++;
                    }
                    if (fieldIndex != -1 && fieldIndex < (Index)inst->getOperand(0)->getOperandCount())
                    {
                        inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)fieldIndex));
                        maybeRemoveOldInst(inst);
                        changed = true;
                    }
                }
            }
            else
            {
                changed |= tryFoldElementExtractFromUpdateInst(inst);
            }
            break;
        case kIROp_GetElement:
            if (inst->getOperand(0)->getOp() == kIROp_MakeArray)
            {
                // getElement(makeArray(a0, a1, ...), i) -> a_i for a literal in-range i.
                auto index = as<IRIntLit>(as<IRGetElement>(inst)->getIndex());
                if (!index)
                    break;
                auto opCount = inst->getOperand(0)->getOperandCount();
                if ((UInt)index->getValue() < opCount)
                {
                    inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)index->getValue()));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            else if (inst->getOperand(0)->getOp() == kIROp_MakeVector)
            {
                // makeVector operands may be scalars or vectors; walk them, tracking the
                // flattened element position, to find the operand holding the index.
                auto index = as<IRIntLit>(as<IRGetElement>(inst)->getIndex());
                if (!index)
                    break;
                auto opCount = inst->getOperand(0)->getOperandCount();
                IRIntegerValue startIndex = 0;
                for (UInt i = 0; i < opCount; i++)
                {
                    auto element = inst->getOperand(0)->getOperand(i);
                    if (auto elementVecType = as<IRVectorType>(element->getDataType()))
                    {
                        auto vecSize = as<IRIntLit>(elementVecType->getElementCount());
                        if (!vecSize)
                            break;
                        if (index->getValue() >= startIndex && index->getValue() < startIndex + vecSize->getValue())
                        {
                            IRBuilder builder(module);
                            builder.setInsertBefore(inst);
                            auto newElement = builder.emitElementExtract(element, builder.getIntValue(builder.getIntType(), index->getValue() - startIndex));
                            inst->replaceUsesWith(newElement);
                            maybeRemoveOldInst(inst);
                            changed = true;
                            break;
                        }
                        startIndex += vecSize->getValue();
                    }
                    else
                    {
                        if (startIndex == index->getValue())
                        {
                            inst->replaceUsesWith(element);
                            maybeRemoveOldInst(inst);
                            changed = true;
                            break;
                        }
                        startIndex++;
                    }
                }
            }
            else if (inst->getOperand(0)->getOp() == kIROp_MakeArrayFromElement || inst->getOperand(0)->getOp() == kIROp_MakeVectorFromScalar)
            {
                // Every element of a splatted array/vector is the splatted value.
                inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
                maybeRemoveOldInst(inst);
                changed = true;
            }
            else
            {
                changed |= tryFoldElementExtractFromUpdateInst(inst);
            }
            break;
        case kIROp_UpdateElement:
            {
                auto updateInst = as<IRUpdateElement>(inst);
                if (updateInst->getAccessKeyCount() != 1)
                    break;
                auto key = updateInst->getAccessKey(0);
                if (auto constIndex = as<IRIntLit>(key))
                {
                    auto oldVal = inst->getOperand(0);
                    if (oldVal->getOp() == kIROp_MakeArray ||
                        oldVal->getOp() == kIROp_MakeArrayFromElement)
                    {
                        // updateElement(makeArray(...), i, v) -> makeArray with element i = v.
                        auto arrayType = as<IRArrayType>(inst->getDataType());
                        if (!arrayType) break;
                        auto arraySize = as<IRIntLit>(arrayType->getElementCount());
                        if (!arraySize) break;
                        List<IRInst*> args;
                        for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
                        {
                            IRInst* arg = nullptr;
                            if (i < (IRIntegerValue)oldVal->getOperandCount())
                                arg = oldVal->getOperand((UInt)i);
                            else if (oldVal->getOperandCount() != 0)
                                arg = oldVal->getOperand(0);
                            else
                                break;
                            if (i == (IRIntegerValue)constIndex->getValue())
                                arg = updateInst->getElementValue();
                            args.add(arg);
                        }
                        if (args.getCount() == arraySize->getValue())
                        {
                            IRBuilder builder(module);
                            builder.setInsertBefore(inst);
                            auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
                            inst->replaceUsesWith(makeArray);
                            maybeRemoveOldInst(inst);
                            changed = true;
                        }
                    }
                    else
                    {
                        // Check if the updated value is a chain of `updateElement` instructions that
                        // updates every element in the same array, and if so we can replace the
                        // whole chain with a single `makeArray` instruction.
                        auto arrayType = as<IRArrayType>(inst->getDataType());
                        if (!arrayType) break;
                        auto arraySize = as<IRIntLit>(arrayType->getElementCount());
                        if (!arraySize) break;
                        List<IRInst*> args;
                        args.setCount((UInt)arraySize->getValue());
                        for (Index i = 0; i < args.getCount(); i++)
                            args[i] = nullptr;
                        for (auto updateElement = updateInst; updateElement;
                             updateElement = as<IRUpdateElement>(updateElement->getOldValue()))
                        {
                            // Only a single-key update overwrites a whole element. A longer
                            // access chain (e.g. [i, field]) writes only part of element i,
                            // so it cannot supply that element's value.
                            if (updateElement->getAccessKeyCount() != 1)
                                break;
                            auto subConstIndex = as<IRIntLit>(updateElement->getAccessKey(0));
                            if (!subConstIndex)
                                break;
                            auto index = subConstIndex->getValue();
                            if (index < 0 || index >= (IRIntegerValue)args.getCount())
                                break;
                            // If we have already seen an update for this index, then we can't
                            // override it with an earlier update.
                            if (args[(Index)index])
                                continue;
                            args[(Index)index] = updateElement->getElementValue();
                        }
                        bool isComplete = true;
                        for (auto arg : args)
                        {
                            if (!arg)
                            {
                                isComplete = false;
                                break;
                            }
                        }
                        if (isComplete)
                        {
                            IRBuilder builder(module);
                            builder.setInsertBefore(inst);
                            auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
                            inst->replaceUsesWith(makeArray);
                            maybeRemoveOldInst(inst);
                            changed = true;
                        }
                    }
                }
                else if (as<IRStructKey>(key))
                {
                    auto oldVal = inst->getOperand(0);
                    if (oldVal->getOp() == kIROp_MakeStruct)
                    {
                        // If we see updateElement(makeStruct(...), structKey, ...), we can
                        // replace it with a makeStruct that has the updated value.
                        auto structType = as<IRStructType>(inst->getDataType());
                        if (!structType) break;
                        List<IRInst*> args;
                        UInt i = 0;
                        bool isValid = true;
                        for (auto field : structType->getFields())
                        {
                            IRInst* arg = nullptr;
                            if (i < oldVal->getOperandCount())
                                arg = oldVal->getOperand(i);
                            if (field->getKey() == key)
                                arg = updateInst->getElementValue();
                            if (arg)
                            {
                                args.add(arg);
                            }
                            else
                            {
                                isValid = false;
                                break;
                            }
                            i++;
                        }
                        if (isValid)
                        {
                            IRBuilder builder(module);
                            builder.setInsertBefore(inst);
                            auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
                            inst->replaceUsesWith(makeStruct);
                            maybeRemoveOldInst(inst);
                            changed = true;
                        }
                    }
                    else
                    {
                        // Check if the updated `oldVal` is a chain of updateElement insts that assigns
                        // values to every field of the struct, if so, we can just emit a makeStruct instead.
                        Dictionary<IRStructKey*, IRInst*> mapFieldKeyToVal;
                        for (auto updateElement = as<IRUpdateElement>(inst); updateElement;
                             updateElement = as<IRUpdateElement>(updateElement->getOldValue()))
                        {
                            if (updateElement->getAccessKeyCount() != 1)
                                break;
                            auto subStructKey = as<IRStructKey>(updateElement->getAccessKey(0));
                            if (!subStructKey)
                                break;
                            // If the key already exists, it means there is already a later update at this key.
                            // We need to be careful not to override it with an earlier value.
                            // AddIfNotExists will ensure this does not happen.
                            mapFieldKeyToVal.addIfNotExists(
                                subStructKey, updateElement->getElementValue());
                        }
                        // Check if every field of the struct has a value assigned to it,
                        // while build up arguments for makeStruct inst at the same time.
                        auto structType = as<IRStructType>(inst->getDataType());
                        if (!structType) break;
                        List<IRInst*> args;
                        bool isComplete = true;
                        for (auto field : structType->getFields())
                        {
                            IRInst* arg = nullptr;
                            if (mapFieldKeyToVal.tryGetValue(field->getKey(), arg))
                            {
                                args.add(arg);
                            }
                            else
                            {
                                isComplete = false;
                                break;
                            }
                        }
                        if (!isComplete) break;
                        // Create a makeStruct inst using args.
                        IRBuilder builder(module);
                        builder.setInsertBefore(inst);
                        auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
                        inst->replaceUsesWith(makeStruct);
                        maybeRemoveOldInst(inst);
                        changed = true;
                    }
                }
            }
            break;
        case kIROp_CastPtrToBool:
            {
                // castPtrToBool(p) -> p != null
                auto ptr = inst->getOperand(0);
                IRBuilder builder(module);
                builder.setInsertBefore(inst);
                auto neq = builder.emitNeq(ptr, builder.getNullVoidPtrValue());
                inst->replaceUsesWith(neq);
                maybeRemoveOldInst(inst);
                changed = true;
            }
            break;
        case kIROp_IsType:
            {
                // isType(v, T) -> true when v's type is already T.
                auto isTypeInst = as<IRIsType>(inst);
                auto actualType = isTypeInst->getValue()->getDataType();
                if (isTypeEqual(actualType, (IRType*)isTypeInst->getTypeOperand()))
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto trueVal = builder.getBoolValue(true);
                    inst->replaceUsesWith(trueVal);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_Reinterpret:
        case kIROp_BitCast:
        case kIROp_IntCast:
        case kIROp_FloatCast:
            {
                // A cast to the operand's own type is a no-op.
                if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType()))
                {
                    inst->replaceUsesWith(inst->getOperand(0));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_UnpackAnyValue:
            {
                // unpack(pack(v)) -> v when the unpacked type matches v's type.
                if (inst->getOperand(0)->getOp() == kIROp_PackAnyValue)
                {
                    if (isTypeEqual(inst->getOperand(0)->getOperand(0)->getDataType(), inst->getDataType()))
                    {
                        inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
                        maybeRemoveOldInst(inst);
                        changed = true;
                    }
                }
            }
            break;
        case kIROp_PackAnyValue:
            {
                // Pack(obj: anyValueN) : anyValueN --> obj
                if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType()))
                {
                    inst->replaceUsesWith(inst->getOperand(0));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_GetOptionalValue:
            {
                if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
                {
                    inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_OptionalHasValue:
            {
                if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue)
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto trueVal = builder.getBoolValue(true);
                    inst->replaceUsesWith(trueVal);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
                else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone)
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto falseVal = builder.getBoolValue(false);
                    inst->replaceUsesWith(falseVal);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_GetNativePtr:
            {
                if (inst->getOperand(0)->getOp() == kIROp_PtrLit)
                {
                    inst->replaceUsesWith(inst->getOperand(0));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_MakeExistential:
            {
                // makeExistential(extractExistentialValue(e), ...) -> e
                if (inst->getOperand(0)->getOp() == kIROp_ExtractExistentialValue)
                {
                    inst->replaceUsesWith(inst->getOperand(0)->getOperand(0));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_LookupWitness:
            {
                // lookupWitness(concreteTable, key) -> the table entry's satisfying value.
                if (inst->getOperand(0)->getOp() == kIROp_WitnessTable)
                {
                    auto wt = as<IRWitnessTable>(inst->getOperand(0));
                    auto key = inst->getOperand(1);
                    for (auto item : wt->getChildren())
                    {
                        if (auto entry = as<IRWitnessTableEntry>(item))
                        {
                            if (entry->getRequirementKey() == key)
                            {
                                auto value = entry->getSatisfyingVal();
                                inst->replaceUsesWith(value);
                                // Respect `removeOldInst` like every other rewrite.
                                maybeRemoveOldInst(inst);
                                changed = true;
                                break;
                            }
                        }
                    }
                }
            }
            break;
        case kIROp_DefaultConstruct:
            {
                IRBuilder builder(module);
                builder.setInsertBefore(inst);
                // See if we can replace the default construct inst with concrete values.
                if (auto newCtor = builder.emitDefaultConstruct(inst->getFullType(), false))
                {
                    inst->replaceUsesWith(newCtor);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_VectorReshape:
            {
                auto fromType = as<IRVectorType>(inst->getOperand(0)->getDataType());
                auto resultType = as<IRVectorType>(inst->getDataType());
                if (!resultType)
                {
                    if (!fromType)
                    {
                        // scalar -> scalar: the reshape is a no-op.
                        inst->replaceUsesWith(inst->getOperand(0));
                        maybeRemoveOldInst(inst);
                        changed = true;
                        break;
                    }
                    // vector -> scalar: take element 0. `resultType` is null on this
                    // path, so the swizzle is typed with the instruction's scalar type.
                    IRBuilder builder(inst);
                    builder.setInsertBefore(inst);
                    UInt index = 0;
                    auto newInst = builder.emitSwizzle(inst->getDataType(), inst->getOperand(0), 1, &index);
                    inst->replaceUsesWith(newInst);
                    maybeRemoveOldInst(inst);
                    changed = true;
                    break;
                }
                // scalar -> vector is not simplified here.
                if (!fromType)
                    break;
                auto fromCount = as<IRIntLit>(fromType->getElementCount());
                if (!fromCount)
                    break;
                auto toCount = as<IRIntLit>(resultType->getElementCount());
                if (!toCount)
                    break;
                IRBuilder builder(inst);
                builder.setInsertBefore(inst);
                auto newInst = builder.emitVectorReshape(resultType, inst->getOperand(0));
                if (newInst != inst)
                {
                    inst->replaceUsesWith(newInst);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
            }
            break;
        case kIROp_MatrixReshape:
            {
                // Rebuild the matrix row by row: rows present in the source are
                // reshaped to the new column count; extra rows are filled with the
                // element type's default value.
                auto fromType = as<IRMatrixType>(inst->getOperand(0)->getDataType());
                auto resultType = as<IRMatrixType>(inst->getDataType());
                SLANG_ASSERT(fromType && resultType);
                if (!fromType || !resultType)
                    break;
                auto fromRows = as<IRIntLit>(fromType->getRowCount());
                if (!fromRows) break;
                auto fromCols = as<IRIntLit>(fromType->getColumnCount());
                if (!fromCols) break;
                auto toRows = as<IRIntLit>(resultType->getRowCount());
                if (!toRows) break;
                auto toCols = as<IRIntLit>(resultType->getColumnCount());
                if (!toCols) break;
                List<IRInst*> rows;
                IRBuilder builder(inst);
                builder.setInsertBefore(inst);
                auto toRowType = builder.getVectorType(resultType->getElementType(), resultType->getColumnCount());
                for (IRIntegerValue i = 0; i < toRows->getValue(); i++)
                {
                    if (i < fromRows->getValue())
                    {
                        auto originalRow = builder.emitElementExtract(inst->getOperand(0), i);
                        auto resizedRow = builder.emitVectorReshape(toRowType, originalRow);
                        rows.add(resizedRow);
                    }
                    else
                    {
                        auto zero = builder.emitDefaultConstruct(resultType->getElementType());
                        auto row = builder.emitMakeVectorFromScalar(toRowType, zero);
                        rows.add(row);
                    }
                }
                auto newInst = builder.emitMakeMatrix(resultType, (UInt)rows.getCount(), rows.getBuffer());
                inst->replaceUsesWith(newInst);
                maybeRemoveOldInst(inst);
                changed = true;
            }
            break;
        case kIROp_Add:
        case kIROp_Mul:
        case kIROp_Sub:
        case kIROp_Div:
        case kIROp_And:
        case kIROp_Or:
            changed |= tryOptimizeArithmeticInst(inst);
            break;
        case kIROp_Param:
            {
                auto block = as<IRBlock>(inst->parent);
                if (!block)
                    break;
                // Position of this parameter among the block's leading params; it is
                // also the index of the matching argument on each incoming branch.
                UInt paramIndex = 0;
                auto prevParam = inst->getPrevInst();
                while (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(prevParam))
                {
                    prevParam = prevParam->getPrevInst();
                    paramIndex++;
                }
                // Find the single value all predecessors pass for this parameter.
                // `undefined` arguments, and arguments that are the parameter itself
                // (a back edge carrying the value through unchanged), do not constrain it.
                IRInst* argValue = nullptr;
                for (auto pred : block->getPredecessors())
                {
                    auto terminator = as<IRUnconditionalBranch>(pred->getTerminator());
                    if (!terminator)
                        continue;
                    SLANG_ASSERT(terminator->getArgCount() > paramIndex);
                    if (terminator->getArgCount() <= paramIndex)
                    {
                        // Malformed branch: give up rather than read past the args.
                        argValue = nullptr;
                        break;
                    }
                    auto arg = terminator->getArg(paramIndex);
                    if (arg->getOp() == kIROp_undefined || arg == inst)
                        continue;
                    if (argValue == nullptr)
                    {
                        argValue = arg;
                    }
                    else if (argValue != arg)
                    {
                        argValue = nullptr;
                        break;
                    }
                }
                if (argValue)
                {
                    if (inst->hasUses())
                    {
                        // Is argValue not a local value, i.e. it's not a child
                        // of a block, and it's 'visible' from inst because
                        // inst is a descendent of argValue's parent
                        if (!as<IRBlock>(argValue->getParent())
                            && isChildInstOf(inst, argValue->getParent()))
                        {
                            inst->replaceUsesWith(argValue);
                            // Never remove param inst.
                            changed = true;
                        }
                        else
                        {
                            // If argValue is defined locally,
                            // we can replace only if argVal dominates inst.
                            auto parentFunc = getParentFunc(inst);
                            if (!parentFunc)
                                break;
                            auto domTree = parentFunc->getModule()->findOrCreateDominatorTree(parentFunc);
                            if (domTree->dominates(argValue, inst))
                            {
                                inst->replaceUsesWith(argValue);
                                // Never remove param inst.
                                changed = true;
                            }
                        }
                    }
                }
            }
            break;
        case kIROp_swizzle:
            {
                // If we see a swizzle(scalar), we replace it with makeVectorFromScalar
                // (or with the scalar itself when the result has a single element).
                if (as<IRBasicType>(inst->getOperand(0)->getDataType()))
                {
                    auto vectorType = as<IRVectorType>(inst->getDataType());
                    IRIntegerValue vectorSize = 1;
                    if (vectorType)
                    {
                        // A non-literal element count is treated as "not 1".
                        auto sizeLit = as<IRIntLit>(vectorType->getElementCount());
                        vectorSize = sizeLit ? sizeLit->getValue() : 0;
                    }
                    if (vectorSize == 1)
                    {
                        inst->replaceUsesWith(inst->getOperand(0));
                        maybeRemoveOldInst(inst);
                        changed = true;
                        break;
                    }
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto newInst = builder.emitMakeVectorFromScalar(vectorType, inst->getOperand(0));
                    inst->replaceUsesWith(newInst);
                    maybeRemoveOldInst(inst);
                    changed = true;
                    break;
                }
                // If we see a swizzle(makeVector) then we can replace it with the values from makeVector.
                auto makeVector = inst->getOperand(0);
                if (makeVector->getOp() != kIROp_MakeVector)
                    break;
                auto swizzle = as<IRSwizzle>(inst);
                List<IRInst*> vals;
                auto vectorType = as<IRVectorType>(makeVector->getDataType());
                auto vectorSize = as<IRIntLit>(vectorType->getElementCount());
                if (!vectorSize)
                    break;
                // Only handle makeVector with one scalar operand per element.
                if (makeVector->getOperandCount() != (UInt)vectorSize->getValue())
                    break;
                for (UInt i = 0; i < swizzle->getElementCount(); i++)
                {
                    auto index = swizzle->getElementIndex(i);
                    auto intLitIndex = as<IRIntLit>(index);
                    if (!intLitIndex)
                        return;
                    if (intLitIndex->getValue() < (Int)makeVector->getOperandCount())
                        vals.add(makeVector->getOperand((UInt)intLitIndex->getValue()));
                    else
                        return;
                }
                if (vals.getCount() == 1)
                {
                    inst->replaceUsesWith(vals[0]);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
                else
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto newMakeVector = builder.emitMakeVector(
                        swizzle->getDataType(), (UInt)vals.getCount(), vals.getBuffer());
                    inst->replaceUsesWith(newMakeVector);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
                break;
            }
        case kIROp_TypeEquals:
            {
                // Compare the operand types by identity once both are concrete.
                auto left = inst->getOperand(0)->getDataType();
                auto right = inst->getOperand(1)->getDataType();
                if (isConcreteType(left) && isConcreteType(right))
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    bool result = left == right;
                    inst->replaceUsesWith(builder.getBoolValue(result));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
                break;
            }
        case kIROp_GetNaturalStride:
            {
                if (targetRequest)
                {
                    if (isInGeneric)
                        break;
                    auto type = inst->getOperand(0)->getDataType();
                    IRSizeAndAlignment sizeAlignment;
                    // Only fold when the layout query succeeded; otherwise
                    // `sizeAlignment` does not describe `type`.
                    if (SLANG_FAILED(getNaturalSizeAndAlignment(targetRequest, type, &sizeAlignment)))
                        break;
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    auto stride = builder.getIntValue(inst->getDataType(), sizeAlignment.getStride());
                    inst->replaceUsesWith(stride);
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
                break;
            }
        case kIROp_IsInt:
        case kIROp_IsFloat:
        case kIROp_IsUnsignedInt:
        case kIROp_IsSignedInt:
        case kIROp_IsBool:
        case kIROp_IsVector:
            {
                auto operandType = inst->getOperand(0)->getDataType();
                // The scalar-kind predicates look through vector/matrix types to
                // their element type.
                auto type = operandType;
                if (auto vectorType = as<IRVectorType>(type))
                    type = vectorType->getElementType();
                if (auto matType = as<IRMatrixType>(type))
                    type = matType->getElementType();
                if (isConcreteType(type))
                {
                    IRBuilder builder(module);
                    builder.setInsertBefore(inst);
                    bool result = false;
                    switch (inst->getOp())
                    {
                    case kIROp_IsInt:
                        result = isIntegralType(type);
                        break;
                    case kIROp_IsBool:
                        result = type->getOp() == kIROp_BoolType;
                        break;
                    case kIROp_IsFloat:
                        result = isFloatingType(type);
                        break;
                    case kIROp_IsUnsignedInt:
                        result = isIntegralType(type) && !getIntTypeInfo(type).isSigned;
                        break;
                    case kIROp_IsSignedInt:
                        result = isIntegralType(type) && getIntTypeInfo(type).isSigned;
                        break;
                    case kIROp_IsVector:
                        // `type` was reduced to the element type above, so test the
                        // operand's own type instead.
                        result = as<IRVectorType>(operandType) != nullptr;
                        break;
                    }
                    inst->replaceUsesWith(builder.getBoolValue(result));
                    maybeRemoveOldInst(inst);
                    changed = true;
                }
                break;
            }
        default:
            break;
        }
    }

    /// True if `type` is defined at module scope and is not a generic parameter.
    /// Null types, or types with no parent, are not considered concrete.
    bool isConcreteType(IRType* type)
    {
        if (!type || !type->parent)
            return false;
        return type->parent->getOp() == kIROp_Module && !as<IRGlobalGenericParam>(type);
    }

    /// Run `processInst` over every descendant of `func` (including `func`),
    /// repeating whole sweeps until one makes no change.
    /// Returns true if any sweep changed something.
    bool processFunc(IRInst* func)
    {
        func->getModule()->invalidateAllAnalysis();
        // Remember the caller's generic state so nested calls restore it.
        bool lastIsInGeneric = isInGeneric;
        if (!isInGeneric)
            isInGeneric = as<IRGeneric>(func) != nullptr;
        bool result = false;
        for (;;)
        {
            changed = false;
            processChildInsts(func, [this](IRInst* inst) { processInst(inst); });
            if (changed)
                result = true;
            else
                break;
        }
        isInGeneric = lastIsInGeneric;
        return result;
    }

    /// Run the pass over the whole module, starting from the module instruction.
    bool processModule()
    {
        return processFunc(module->getModuleInst());
    }
};
/// Run the peephole pass over the whole module, starting at its module
/// instruction and sweeping until a sweep makes no change.
/// Returns true if anything was rewritten.
bool peepholeOptimize(TargetRequest* target, IRModule* module)
{
    PeepholeContext peephole(module);
    peephole.targetRequest = target;
    const bool madeChanges = peephole.processModule();
    return madeChanges;
}
/// Run the peephole pass over `func` and everything nested inside it.
/// Returns true if anything was rewritten.
bool peepholeOptimize(TargetRequest* target, IRInst* func)
{
    IRModule* owningModule = func->getModule();
    PeepholeContext peephole(owningModule);
    peephole.targetRequest = target;
    return peephole.processFunc(func);
}
/// Apply `processInst` to each direct child of the module instruction (without
/// descending into function bodies), repeating until a sweep makes no change.
/// Returns true if anything was rewritten.
bool peepholeOptimizeGlobalScope(TargetRequest* target, IRModule* module)
{
    PeepholeContext context(module);
    context.targetRequest = target;
    bool result = false;
    for (;;)
    {
        context.changed = false;
        // Capture each successor before visiting: `processInst` may remove the
        // instruction it is given, which unlinks it from the list. Advancing from
        // it afterwards (as a range-for would) would end the sweep early.
        IRInst* nextInst = nullptr;
        for (IRInst* globalInst = module->getModuleInst()->getFirstChild(); globalInst; globalInst = nextInst)
        {
            nextInst = globalInst->getNextInst();
            context.processInst(globalInst);
        }
        result |= context.changed;
        if (!context.changed)
            break;
    }
    return result;
}
/// Try to simplify a single instruction.
///
/// Constant folding via `tryConstantFoldInst` is attempted first; a result
/// other than `inst` counts as success. Otherwise one peephole step is applied
/// to `inst` with `removeOldInst` disabled, which asks the context to redirect
/// uses without removing the instruction. Returns true if either step made a change.
bool tryReplaceInstUsesWithSimplifiedValue(TargetRequest* target, IRModule* module, IRInst* inst)
{
    const bool constantFolded = tryConstantFoldInst(module, inst) != inst;
    if (constantFolded)
        return true;

    PeepholeContext peephole(inst->getModule());
    peephole.targetRequest = target;
    peephole.removeOldInst = false;
    peephole.processInst(inst);
    return peephole.changed;
}
} // namespace Slang