https://github.com/shader-slang/slang
Tip revision: e59516fa8c3a16eb7b99a928c5b85b97bf44fd72 authored by Yong He on 01 February 2022, 00:26:03 UTC
Revise entrypoint renaming interface. (#2113)
Revise entrypoint renaming interface. (#2113)
Tip revision: e59516f
slang-hlsl-intrinsic-set.cpp
// slang-hlsl-intrinsic-set.cpp
#include "slang-hlsl-intrinsic-set.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
namespace Slang
{
// Static table of per-operation metadata. It is generated by expanding the
// SLANG_HLSL_INTRINSIC_OP list macro with SLANG_HLSL_INTRINSIC_OP_INFO, so it
// holds one entry per operation in that list, in list order.
//
// Each entry stores, in order:
//   - the operation identifier spelled as a string (`#x`),
//   - the function name literal supplied by the list (`funcName`),
//   - the operand count, narrowed to `int8_t`.
//
// NOTE(review): HLSLIntrinsicOpLookup's constructor walks this table by index and
// converts the index to an `Op`, so the list order presumably matches the `Op`
// enum order - confirm against the header.
/* static */const HLSLIntrinsic::Info HLSLIntrinsic::s_operationInfos[] =
{
#define SLANG_HLSL_INTRINSIC_OP_INFO(x, funcName, numOperands) { UnownedStringSlice::fromLiteral(#x), UnownedStringSlice::fromLiteral(funcName), int8_t(numOperands) },
SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_INFO)
};
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicSet !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Constructs an empty intrinsic set.
//
// `typeSet` is stored and later used to map types into the set's module and to
// obtain the IRBuilder used for building signature types.
// `lookup` is stored and later used to map call targets to intrinsic operations.
// Both pointers are kept as-is (no copy is made here).
HLSLIntrinsicSet::HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup):
// Storage for the HLSLIntrinsic instances held by the set. The arguments are the
// element size, the element alignment and 1024 (presumably the number of elements
// per allocation block - TODO confirm against the free list's constructor).
m_intrinsicFreeList(sizeof(HLSLIntrinsic), SLANG_ALIGN_OF(HLSLIntrinsic), 1024),
m_typeSet(typeSet),
m_opLookup(lookup)
{
}
// Returns the basic (scalar) type underlying `type`.
//
// For a vector or matrix type the element type is examined; any other type is
// examined directly. The result is nullptr when the examined type is not an
// IRBasicType.
static IRBasicType* _getElementType(IRType* type)
{
    const auto typeOp = type->getOp();

    IRType* candidate = type;
    if (typeOp == kIROp_VectorType)
    {
        candidate = static_cast<IRVectorType*>(type)->getElementType();
    }
    else if (typeOp == kIROp_MatrixType)
    {
        candidate = static_cast<IRMatrixType*>(type)->getElementType();
    }

    return dynamicCast<IRBasicType>(candidate);
}
// Fills `out` with the operation, return type and signature type for an intrinsic.
//
// `returnType` and the `argsCount` entries of `inArgs` are asserted to belong to
// the module of the type set's builder.
//
// The signature is a function type whose result is always `void` and whose
// parameters depend on `op`:
//   - GetAt:               the argument types, with a pointer type replaced by its
//                          first operand (the same is done for the return type).
//   - ConstructFromScalar: { returnType, element type of returnType }.
//   - ConstructConvert:    { returnType, inArgs[0] }.
//   - anything else:       the argument types unchanged.
void HLSLIntrinsicSet::_calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out)
{
    IRBuilder& builder = m_typeSet->getBuilder();

    // Every type handed in must live in the builder's module.
    IRModule* module = builder.getModule();
    SLANG_UNUSED(module);

    SLANG_ASSERT(returnType->getModule() == module);
    for (Index argIndex = 0; argIndex < argsCount; ++argIndex)
    {
        SLANG_ASSERT(inArgs[argIndex]->getModule() == module);
    }

    out.op = op;
    out.returnType = returnType;

    // Unless a case below says otherwise, the signature parameters are the
    // arguments as given.
    IRType* rewrittenTypes[3];
    IRType* const* paramTypes = inArgs;
    Index paramCount = argsCount;

    switch (op)
    {
        case Op::GetAt:
        {
            SLANG_ASSERT(argsCount == 2 || argsCount == 3);

            // TODO(JS):
            // HACK! GetAt can come from getElementPtr as well as getElement. With
            // getElementPtr the return type is a pointer, which we don't want to
            // deal with here, so it is stripped.
            if (returnType->getOp() == kIROp_PtrType)
            {
                returnType = as<IRType>(returnType->getOperand(0));
            }
            out.returnType = returnType;

            // TODO(JS): The same stripping is applied to the parameters.
            for (Index argIndex = 0; argIndex < argsCount; ++argIndex)
            {
                IRType* paramType = inArgs[argIndex];
                if (paramType->getOp() == kIROp_PtrType)
                {
                    paramType = as<IRType>(paramType->getOperand(0));
                }
                rewrittenTypes[argIndex] = paramType;
            }
            paramTypes = rewrittenTypes;
            break;
        }
        case Op::ConstructFromScalar:
        {
            SLANG_ASSERT(argsCount == 1);

            // The destination type is part of the signature, followed by its scalar type.
            rewrittenTypes[0] = returnType;
            rewrittenTypes[1] = _getElementType(returnType);
            paramTypes = rewrittenTypes;
            paramCount = 2;
            break;
        }
        case Op::ConstructConvert:
        {
            SLANG_ASSERT(argsCount == 1);

            // The return type is made a parameter so that the signature takes it into account.
            rewrittenTypes[0] = returnType;
            rewrittenTypes[1] = inArgs[0];
            paramTypes = rewrittenTypes;
            paramCount = 2;
            break;
        }
        default:
            break;
    }

    out.signatureType = builder.getFuncType(paramCount, paramTypes, builder.getVoidType());
}
// Computes the intrinsic `out` for `op` from an explicit return type and an array
// of `argCount` argument types.
//
// The return type and every argument type are first passed through the type
// set's `getType`, and the results are handed to `_calcIntrinsic`. Up to 8
// arguments are staged in a stack array; more than that are staged in a List.
void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgTypes, Index argCount, HLSLIntrinsic& out)
{
    const Index kInlineArgCount = 8;

    returnType = m_typeSet->getType(returnType);

    // Choose the scratch storage: the stack array when it is big enough,
    // otherwise a list sized to hold every argument.
    IRType* inlineArgs[kInlineArgCount];
    List<IRType*> overflowArgs;

    IRType** mappedArgs = inlineArgs;
    if (argCount > kInlineArgCount)
    {
        overflowArgs.setCount(argCount);
        mappedArgs = overflowArgs.getBuffer();
    }

    for (Index argIndex = 0; argIndex < argCount; ++argIndex)
    {
        mappedArgs[argIndex] = m_typeSet->getType(inArgTypes[argIndex]);
    }

    _calcIntrinsic(op, returnType, mappedArgs, argCount, out);
}
// Computes the intrinsic `out` for `op` from an instruction.
//
// The return type is the data type of `inst`, and the argument types are the data
// types of its first `operandCount` operands. All of them are passed through the
// type set's `getType` before being handed to `_calcIntrinsic`. Up to 8 operand
// types are staged in a stack array; more than that are staged in a List.
void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRInst* inst, Index operandCount, HLSLIntrinsic& out)
{
    const Index kInlineArgCount = 8;

    IRType* returnType = m_typeSet->getType(inst->getDataType());

    // Choose the scratch storage: the stack array when it is big enough,
    // otherwise a list sized to hold every operand type.
    IRType* inlineTypes[kInlineArgCount];
    List<IRType*> overflowTypes;

    IRType** operandTypes = inlineTypes;
    if (operandCount > kInlineArgCount)
    {
        overflowTypes.setCount(operandCount);
        operandTypes = overflowTypes.getBuffer();
    }

    for (Index operandIndex = 0; operandIndex < operandCount; ++operandIndex)
    {
        IRInst* operand = inst->getOperand(operandIndex);
        operandTypes[operandIndex] = m_typeSet->getType(operand->getDataType());
    }

    _calcIntrinsic(op, returnType, operandTypes, operandCount, out);
}
// Computes the intrinsic `out` for `op` from an explicit return type and an array
// of `argCount` uses.
//
// The argument types are the data types of the values the uses refer to. The
// return type and all argument types are passed through the type set's `getType`
// before being handed to `_calcIntrinsic`. Up to 8 argument types are staged in a
// stack array; more than that are staged in a List.
void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRUse* inArgs, Index argCount, HLSLIntrinsic& out)
{
    const Index kInlineArgCount = 8;

    returnType = m_typeSet->getType(returnType);

    // Choose the scratch storage: the stack array when it is big enough,
    // otherwise a list sized to hold every argument type.
    IRType* inlineTypes[kInlineArgCount];
    List<IRType*> overflowTypes;

    IRType** usedTypes = inlineTypes;
    if (argCount > kInlineArgCount)
    {
        overflowTypes.setCount(argCount);
        usedTypes = overflowTypes.getBuffer();
    }

    for (Index argIndex = 0; argIndex < argCount; ++argIndex)
    {
        IRInst* usedValue = inArgs[argIndex].get();
        usedTypes[argIndex] = m_typeSet->getType(usedValue->getDataType());
    }

    _calcIntrinsic(op, returnType, usedTypes, argCount, out);
}
// Adds the intrinsic required by `inst` to the set, if there is one.
//
// Returns the intrinsic held by the set, or nullptr when `makeIntrinsic` does not
// succeed for `inst`.
HLSLIntrinsic* HLSLIntrinsicSet::add(IRInst* inst)
{
    HLSLIntrinsic described;

    const SlangResult result = makeIntrinsic(inst, described);
    if (!SLANG_SUCCEEDED(result))
    {
        return nullptr;
    }

    return add(described);
}
// Works out which intrinsic (if any) is needed to emit `inst`.
//
// On success `out` is filled in (operation, return type and signature type) and
// SLANG_OK is returned. When `inst` does not map to an intrinsic SLANG_FAIL is
// returned and `out.op` is left as Op::Invalid.
//
// The instruction opcode is first looked up in the direct opcode -> operation
// table. Opcodes without a direct mapping are then handled case by case below.
SlangResult HLSLIntrinsicSet::makeIntrinsic(IRInst* inst, HLSLIntrinsic& out)
{
    // Mark as invalid until a mapping is found.
    out.op = Op::Invalid;

    {
        // See if we can just directly convert
        Op op = HLSLIntrinsicOpLookup::getOpForIROp(inst->getOp());

        // HACK: some cases we want to stop handling via the synthesis
        // path, but only for vector and matrix types (not scalars).
        //
        // NOTE(review): the opcode lookup in this file does not return AsFloat,
        // AsInt or AsUInt, so this guard looks unreachable at present - confirm.
        switch (op)
        {
            default: break;

            case Op::AsFloat:
            case Op::AsInt:
            case Op::AsUInt:
            // Note: the `any()`/`all()` case can't be handled via a stdlib definition
            // right now because `bool` vectors map to `int` vectors on the CUDA
            // path, so that the generated `getAt` operation is incorrect.
            //
            // case Op::Any:
            // case Op::All:
            {
                IRType* srcType = inst->getOperand(0)->getDataType();
                switch (srcType->getOp())
                {
                    default:
                        break;

                    case kIROp_VectorType:
                    case kIROp_MatrixType:
                        return SLANG_FAIL;
                }
            }
            break;
        }

        if (op != Op::Invalid)
        {
            calcIntrinsic(op, inst, inst->getOperandCount(), out);
            return SLANG_OK;
        }
    }

    // All the special cases
    switch (inst->getOp())
    {
        case kIROp_constructVectorFromScalar:
        {
            // NOTE(review): the direct lookup above already maps this opcode to
            // ConstructFromScalar, so this case looks unreachable - confirm.
            SLANG_ASSERT(inst->getOperandCount() == 1);
            calcIntrinsic(Op::ConstructFromScalar, inst, 1, out);
            return SLANG_OK;
        }
        case kIROp_Construct:
        {
            IRType* dstType = inst->getDataType();

            const bool dstIsVectorOrMatrix =
                dstType->getOp() == kIROp_VectorType || dstType->getOp() == kIROp_MatrixType;

            if (dstIsVectorOrMatrix && inst->getOperandCount() == 1)
            {
                // Operand 0 is only read once we know there is exactly one operand,
                // so a construct with no operands never reads out of range.
                IRType* srcType = inst->getOperand(0)->getDataType();

                if (as<IRBasicType>(srcType))
                {
                    calcIntrinsic(Op::ConstructFromScalar, inst, out);
                }
                else
                {
                    SLANG_ASSERT(m_typeSet->getType(dstType) != m_typeSet->getType(srcType));
                    // If it's constructed from a type conversion
                    calcIntrinsic(Op::ConstructConvert, inst, out);
                }
                return SLANG_OK;
            }
            else
            {
                // If we are constructing a basic type, we don't need an Op::Init
                if (!IRBasicType::isaImpl(dstType->getOp()))
                {
                    // Emit the 'init' intrinsic
                    calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out);
                    return SLANG_OK;
                }
            }
            return SLANG_FAIL;
        }
        case kIROp_makeVector:
        {
            if (inst->getOperandCount() == 1 && as<IRBasicType>(inst->getOperand(0)->getDataType()))
            {
                // This is make from scalar
                calcIntrinsic(Op::ConstructFromScalar, inst, out);
            }
            else
            {
                calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out);
            }
            return SLANG_OK;
        }
        case kIROp_MakeMatrix:
        {
            // We only emit as if it has one operand, but we can tell how many it actually has from the return type
            calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out);
            return SLANG_OK;
        }
        case kIROp_swizzle:
        {
            // We don't need to add swizzle function, but we do output the need for some other functions
            // For C++ we don't need to emit a swizzle function
            // For C we need a construction function
            auto swizzleInst = static_cast<IRSwizzle*>(inst);

            IRInst* baseInst = swizzleInst->getBase();
            IRType* baseType = baseInst->getDataType();

            // If we are swizzling from a built in type,
            if (as<IRBasicType>(baseType))
            {
                // We can swizzle a scalar type to be a vector, or just a scalar
                IRType* dstType = swizzleInst->getDataType();
                if (!as<IRBasicType>(dstType))
                {
                    // If it's a scalar make sure we have construct from scalar, because we will want to use that
                    SLANG_ASSERT(dstType->getOp() == kIROp_VectorType);

                    IRType* argTypes[] = { baseType };
                    calcIntrinsic(Op::ConstructFromScalar, inst->getDataType(), argTypes, 1, out);
                    return SLANG_OK;
                }
                // A scalar swizzled to a scalar needs no intrinsic; falls out to SLANG_FAIL.
            }
            else
            {
                const Index elementCount = Index(swizzleInst->getElementCount());
                if (elementCount >= 1)
                {
                    // Will need to generate a swizzle method
                    calcIntrinsic(Op::Swizzle, inst, out);
                    return SLANG_OK;
                }
            }
            break;
        }
        case kIROp_getElement:
        {
            IRInst* target = inst->getOperand(0);
            IRType* targetType = target->getDataType();

            if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType)
            {
                // Specially handle this
                calcIntrinsic(Op::GetAt, inst, out);
                return SLANG_OK;
            }
            break;
        }
        case kIROp_getElementPtr:
        {
            IRInst* target = inst->getOperand(0);
            IRType* targetType = target->getDataType();

            if (auto ptrType = as<IRPtrType>(targetType))
            {
                // Look through the pointer at the type being pointed to.
                targetType = as<IRType>(ptrType->getOperand(0));
                if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType)
                {
                    // Specially handle this
                    calcIntrinsic(Op::GetAt, inst, out);
                    return SLANG_OK;
                }
            }
            break;
        }
        case kIROp_Call:
        {
            IRCall* callInst = (IRCall*)inst;
            auto funcValue = callInst->getCallee();

            // The operation is determined from the callee, not from the call itself.
            const Op op = m_opLookup->getOpFromTargetDecoration(funcValue);
            if (op != Op::Invalid)
            {
                calcIntrinsic(op, inst->getDataType(), callInst->getArgs(), callInst->getArgCount(), out);
                return SLANG_OK;
            }
            break;
        }
        default: break;
    }

    return SLANG_FAIL;
}
// Appends every intrinsic held in the set's intrinsic list to `out`, in list
// order. Entries already in `out` are left in place.
void HLSLIntrinsicSet::getIntrinsics(List<const HLSLIntrinsic*>& out) const
{
    for (const HLSLIntrinsic* entry : m_intrinsicsList)
    {
        out.add(entry);
    }
}
// Adds `intrinsic` to the set, de-duplicating against the intrinsics already held.
//
// `intrinsic.op` must not be Op::Invalid. Returns the instance held by the set:
// the previously stored one when the dictionary already has a matching entry,
// otherwise a newly stored copy of `intrinsic`.
HLSLIntrinsic* HLSLIntrinsicSet::add(const HLSLIntrinsic& intrinsic)
{
    // Make sure it's valid(!)
    SLANG_ASSERT(intrinsic.op != Op::Invalid);

    // Take a copy in the set's own storage; the dictionary key refers to it.
    HLSLIntrinsic* candidate = static_cast<HLSLIntrinsic*>(m_intrinsicFreeList.allocate());
    *candidate = intrinsic;

    HLSLIntrinsicRef key(candidate);
    HLSLIntrinsic** existing = m_intrinsicsDict.TryGetValueOrAdd(key, candidate);

    if (!existing)
    {
        // First time this intrinsic has been seen, so record it in the
        // de-duplicated list as well.
        m_intrinsicsList.add(candidate);
        return candidate;
    }

    // An entry was already present, so the copy is not needed.
    m_intrinsicFreeList.deallocate(candidate);
    return *existing;
}
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicOpLookup !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Builds the name -> operation lookup.
//
// Every operation whose function name starts with a lower case ASCII letter (so
// named functions, not ops like -, / etc) has its name added to the slice pool.
// The pool index of the name is then used as the index into `m_sliceToOpMap`,
// which is grown as needed, with newly exposed slots set to Op::Invalid.
HLSLIntrinsicOpLookup::HLSLIntrinsicOpLookup():
    m_slicePool(StringSlicePool::Style::Default)
{
    for (int opIndex = 0; opIndex < SLANG_COUNT_OF(HLSLIntrinsic::s_operationInfos); ++opIndex)
    {
        const UnownedStringSlice name = HLSLIntrinsic::getInfo(Op(opIndex)).funcName;

        // Only names that begin with 'a'..'z' take part in the lookup.
        if (name.getLength() <= 0)
        {
            continue;
        }
        if (name[0] < 'a' || name[0] > 'z')
        {
            continue;
        }

        const Index slot = Index(m_slicePool.add(name));

        // Make sure there is space, marking any skipped slots as invalid.
        const Index previousCount = m_sliceToOpMap.getCount();
        if (slot >= previousCount)
        {
            m_sliceToOpMap.setCount(slot + 1);
            for (Index fill = previousCount; fill < slot; ++fill)
            {
                m_sliceToOpMap[fill] = Op::Invalid;
            }
        }

        m_sliceToOpMap[slot] = Op(opIndex);
    }
}
// Returns the operation registered for the function name `slice`.
//
// Op::Invalid is returned when the slice pool's index for `slice` is negative or
// lies outside `m_sliceToOpMap`.
HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpByName(const UnownedStringSlice& slice)
{
    const Index poolIndex = m_slicePool.findIndex(slice);

    if (poolIndex < 0 || poolIndex >= m_sliceToOpMap.getCount())
    {
        return Op::Invalid;
    }
    return m_sliceToOpMap[poolIndex];
}
// Finds the value a `specialize` instruction yields.
//
// When the base of `specInst` is a generic whose last block ends in a
// return-with-value, the returned value is the result. In every other situation
// (base is not a generic, the generic has no last block, or its terminator is not
// a return-with-value) the base itself is returned.
static IRInst* _getSpecializedValue(IRSpecialize* specInst)
{
    IRInst* const fallback = specInst->getBase();

    if (auto generic = as<IRGeneric>(fallback))
    {
        if (auto finalBlock = generic->getLastBlock())
        {
            if (auto returnVal = as<IRReturnVal>(finalBlock->getTerminator()))
            {
                return returnVal->getVal();
            }
        }
    }

    return fallback;
}
// Determines the intrinsic operation for a callee from its target intrinsic
// decorations.
//
// `inInst` is the thing being called. Any chain of `specialize` instructions is
// looked through first. The decorations of the resulting instruction are then
// scanned for target intrinsic decorations with an empty capability set, and the
// first one whose definition names a known operation gives the result.
// Returns Op::Invalid when there is no such decoration.
HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpFromTargetDecoration(IRInst* inInst)
{
    // An intrinsic generic function is invoked through a `specialize` instruction,
    // so the callee is not directly the thing that carries the decoration. Keep
    // looking through specializations until the thing actually called is reached.
    IRInst* callee = inInst;
    for (;;)
    {
        auto specialize = as<IRSpecialize>(callee);
        if (!specialize)
        {
            break;
        }

        callee = _getSpecializedValue(specialize);

        // If the lookup were to hand back the `specialize` instruction itself this
        // loop would never terminate, which would be a disaster for us.
        //
        // This should never happen if the stdlib is well-formed and the compiler
        // is doing its job right.
        SLANG_ASSERT(callee != specialize);
    }

    // We are just looking for the original name so we can match against it
    for (auto decoration : callee->getDecorations())
    {
        auto targetIntrinsic = as<IRTargetIntrinsicDecoration>(decoration);
        if (!targetIntrinsic)
        {
            continue;
        }

        // TODO(JS): Should confirm that we'll always have this entry - lookups need
        // the name, not a target's transformation of it.
        //
        // It turns out that addCatchAllIntrinsicDecorationIfNeeded will add a target
        // intrinsic with the original HLSL name, which has an empty `CapabilitySet`.
        //
        // It's not 100% clear this covers all the cases, but for now let's go with that.
        if (!targetIntrinsic->getTargetCaps().isEmpty())
        {
            continue;
        }

        const Op op = getOpByName(targetIntrinsic->getDefinition());
        if (op != Op::Invalid)
        {
            return op;
        }
    }

    return Op::Invalid;
}
// Returns the intrinsic operation for `inst`, or Op::Invalid if there is none.
//
// For a call the operation is determined from the decorations of the callee.
// For every other instruction the result comes from the opcode-only lookup.
HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IRInst* inst)
{
    if (inst->getOp() == kIROp_Call)
    {
        // The decoration lookup inspects the thing being called (looking through
        // `specialize`), so it has to be given the callee rather than the call
        // instruction itself.
        auto callInst = static_cast<IRCall*>(inst);
        return getOpFromTargetDecoration(callInst->getCallee());
    }

    return getOpForIROp(inst->getOp());
}
// Maps an IR opcode directly to an intrinsic operation.
//
// Only opcodes that correspond one-to-one with an operation are handled here;
// anything else yields Op::Invalid.
/* static */HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IROp op)
{
    // Expands to the case for an opcode whose operation has the same name.
#define SLANG_HLSL_INTRINSIC_DIRECT_OP(NAME) case kIROp_##NAME: return Op::NAME;

    switch (op)
    {
        // Arithmetic
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Add)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Mul)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Sub)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Div)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Lsh)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Rsh)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(IRem)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(FRem)

        // Comparison
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Eql)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Neq)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Greater)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Less)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Geq)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Leq)

        // Bitwise and logical
        SLANG_HLSL_INTRINSIC_DIRECT_OP(BitAnd)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(BitXor)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(BitOr)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(And)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Or)

        // Unary
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Neg)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(Not)
        SLANG_HLSL_INTRINSIC_DIRECT_OP(BitNot)

        // The one mapping where the opcode and operation names differ.
        case kIROp_constructVectorFromScalar:
            return Op::ConstructFromScalar;

        default:
            return Op::Invalid;
    }

#undef SLANG_HLSL_INTRINSIC_DIRECT_OP
}
}