https://github.com/shader-slang/slang
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-any-value-marshalling.cpp
#include "slang-ir-any-value-marshalling.h"
#include "../core/slang-math.h"
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
namespace Slang
{
// This is a subpass of generics lowering IR transformation.
// This pass generates packing/unpacking functions for `AnyValue`s,
// and replaces all `IRPackAnyValue` and `IRUnpackAnyValue` with calls to these
// functions.
struct AnyValueMarshallingContext
{
SharedGenericsLoweringContext* sharedContext;
// Describes the struct type that was generated to implement one `AnyValue<N>` type.
struct AnyValueTypeInfo : RefObject
{
    // The generated IR struct type standing in for `AnyValue<N>`.
    // Null until `ensureAnyValueType` fills it in.
    IRType* type = nullptr;
    // `IRStructKey`s of the generated struct's `uint` fields, in declaration order.
    List<IRStructKey*> fieldKeys;
};
// Cache of the generated struct types, keyed by the value of `IRAnyValueType::getSize()`.
Dictionary<IRIntegerValue, RefPtr<AnyValueTypeInfo>> generatedAnyValueTypes;
struct MarshallingFunctionKey
{
IRType* originalType;
IRIntegerValue anyValueSize;
bool operator ==(MarshallingFunctionKey other) const
{
return originalType == other.originalType && anyValueSize == other.anyValueSize;
}
HashCode getHashCode() const
{
return combineHash(Slang::getHashCode(originalType), Slang::getHashCode(anyValueSize));
}
};
// The pair of functions generated for one `MarshallingFunctionKey`.
struct MarshallingFunctionSet
{
    // Function taking the concrete type and returning the `AnyValue` struct.
    IRFunc* packFunc = nullptr;
    // Function taking the `AnyValue` struct and returning the concrete type.
    IRFunc* unpackFunc = nullptr;
};
// Stores the generated packing/unpacking functions for lookup.
Dictionary<MarshallingFunctionKey, MarshallingFunctionSet> mapTypeMarshalingFunctions;
// Returns the struct type info implementing `type`, generating the struct on first request.
// The generated struct is exported as `AnyValue<size>` and holds one `uint` field for every
// (started) 4 bytes of `size`; results are cached per size in `generatedAnyValueTypes`.
AnyValueTypeInfo* ensureAnyValueType(IRAnyValueType* type)
{
    auto size = getIntVal(type->getSize());
    if (auto cached = generatedAnyValueTypes.tryGetValue(size))
    {
        return cached->Ptr();
    }

    RefPtr<AnyValueTypeInfo> newInfo = new AnyValueTypeInfo();

    // The struct is emitted right before the `AnyValue` type it replaces.
    IRBuilder builder(sharedContext->module);
    builder.setInsertBefore(type);
    auto anyValueStruct = builder.createStructType();
    newInfo->type = anyValueStruct;

    StringBuilder structName;
    structName << "AnyValue" << size;
    builder.addExportDecoration(anyValueStruct, structName.getUnownedSlice());

    // Round the size up to a whole number of `uint32_t` fields.
    auto fieldCount = (size + sizeof(uint32_t) - 1) / sizeof(uint32_t);
    for (decltype(fieldCount) fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++)
    {
        auto fieldKey = builder.createStructKey();

        // Name hint is `field<i>`; the export name appends `_anyVal<size>` to it.
        StringBuilder keyName;
        keyName << "field" << fieldIndex;
        builder.addNameHintDecoration(fieldKey, keyName.getUnownedSlice());
        keyName << "_anyVal" << size;
        builder.addExportDecoration(fieldKey, keyName.getUnownedSlice());

        builder.createStructField(anyValueStruct, fieldKey, builder.getUIntType());
        newInfo->fieldKeys.add(fieldKey);
    }

    generatedAnyValueTypes[size] = newInfo;
    return newInfo.Ptr();
}
// Shared state and callbacks used while walking a concrete type and moving its leaf
// elements to or from the fields of an `AnyValue` struct. Derived contexts decide the
// direction (packing or unpacking) by implementing the two `marshal*` callbacks.
struct TypeMarshallingContext
{
    // Info about the `AnyValue` struct being read or written.
    AnyValueTypeInfo* anyValInfo = nullptr;
    // Index into `anyValInfo->fieldKeys` of the `uint` field currently being accessed.
    uint32_t fieldOffset = 0;
    // Byte offset inside the current 4-byte field (kept in the range [0, 4)).
    uint32_t intraFieldOffset = 0;
    // Pointer-to-`uint` type used for the addresses of `AnyValue` fields.
    IRType* uintPtrType = nullptr;
    // The variable holding the `AnyValue` struct.
    IRInst* anyValueVar = nullptr;

    // This is a polymorphic base, so give it a virtual destructor.
    virtual ~TypeMarshallingContext() = default;

    // Defines what to do with basic typed data elements.
    virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;
    // Defines what to do with resource handle elements.
    virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;

    // Moves to the start of the next field unless already at the start of one.
    void ensureOffsetAt4ByteBoundary()
    {
        if (intraFieldOffset)
        {
            fieldOffset++;
            intraFieldOffset = 0;
        }
    }

    // Moves to the next 2-byte boundary: stays put at offset 0, uses the upper half of the
    // current field when at most 2 bytes are used, otherwise skips to the next field.
    void ensureOffsetAt2ByteBoundary()
    {
        if (intraFieldOffset == 0)
            return;
        if (intraFieldOffset <= 2)
        {
            intraFieldOffset = 2;
            return;
        }
        fieldOffset++;
        intraFieldOffset = 0;
    }

    // Advances the position by `bytes`, carrying whole groups of 4 bytes into `fieldOffset`.
    void advanceOffset(uint32_t bytes)
    {
        intraFieldOffset += bytes;
        fieldOffset += intraFieldOffset / 4;
        intraFieldOffset = intraFieldOffset % 4;
    }
};
// Recursively walks the value type behind the pointer `concreteTypedVar` and hands each
// leaf element to `context`:
//   - basic scalar types go to `context->marshalBasicType`,
//   - textures and sampler states go to `context->marshalResourceHandle`,
//   - vectors, matrices, structs, arrays and nested `AnyValue`s are decomposed element by
//     element, in order, by emitting element/field addresses and recursing.
// Any other type is reported through `SLANG_UNIMPLEMENTED_X`.
//
// `builder` is where the address computations are emitted; `concreteTypedVar` must have a
// pointer type (it is `cast` to `IRPtrTypeBase`).
void emitMarshallingCode(
    IRBuilder* builder,
    TypeMarshallingContext* context,
    IRInst* concreteTypedVar)
{
    auto dataType = cast<IRPtrTypeBase>(concreteTypedVar->getDataType())->getValueType();
    switch (dataType->getOp())
    {
    case kIROp_IntType:
    case kIROp_FloatType:
    case kIROp_UIntType:
    case kIROp_UInt64Type:
    case kIROp_Int64Type:
    case kIROp_DoubleType:
    case kIROp_Int8Type:
    case kIROp_Int16Type:
    case kIROp_UInt8Type:
    case kIROp_UInt16Type:
    case kIROp_HalfType:
    case kIROp_BoolType:
    case kIROp_IntPtrType:
    case kIROp_UIntPtrType:
        context->marshalBasicType(builder, dataType, concreteTypedVar);
        break;
    case kIROp_VectorType:
        {
            auto vectorType = static_cast<IRVectorType*>(dataType);
            auto elementType = vectorType->getElementType();
            auto elementCount = getIntVal(vectorType->getElementCount());
            auto elementPtrType = builder->getPtrType(elementType);
            for (IRIntegerValue i = 0; i < elementCount; i++)
            {
                auto elementAddr = builder->emitElementAddress(
                    elementPtrType,
                    concreteTypedVar,
                    builder->getIntValue(builder->getIntType(), i));
                emitMarshallingCode(builder, context, elementAddr);
            }
            break;
        }
    case kIROp_MatrixType:
        {
            auto matrixType = static_cast<IRMatrixType*>(dataType);
            auto elementType = matrixType->getElementType();
            auto colCount = getIntVal(matrixType->getColumnCount());
            auto rowCount = getIntVal(matrixType->getRowCount());
            // Indexing a matrix selects one of its `rowCount` rows, and every row is a
            // vector of `colCount` elements. Walk rows in the outer loop and the elements
            // of each row in the inner loop so that non-square matrices are addressed
            // within bounds and with the correct row vector type.
            auto rowVecType = builder->getVectorType(elementType, matrixType->getColumnCount());
            auto rowPtrType = builder->getPtrType(rowVecType);
            auto elementPtrType = builder->getPtrType(elementType);
            for (IRIntegerValue i = 0; i < rowCount; i++)
            {
                auto row = builder->emitElementAddress(
                    rowPtrType,
                    concreteTypedVar,
                    builder->getIntValue(builder->getIntType(), i));
                for (IRIntegerValue j = 0; j < colCount; j++)
                {
                    auto element = builder->emitElementAddress(
                        elementPtrType,
                        row,
                        builder->getIntValue(builder->getIntType(), j));
                    emitMarshallingCode(builder, context, element);
                }
            }
            break;
        }
    case kIROp_StructType:
        {
            auto structType = cast<IRStructType>(dataType);
            for (auto field : structType->getFields())
            {
                auto fieldAddr = builder->emitFieldAddress(
                    builder->getPtrType(field->getFieldType()),
                    concreteTypedVar,
                    field->getKey());
                emitMarshallingCode(builder, context, fieldAddr);
            }
            break;
        }
    case kIROp_ArrayType:
        {
            auto arrayType = cast<IRArrayType>(dataType);
            auto elementPtrType = builder->getPtrType(arrayType->getElementType());
            // The element count does not change while iterating; read it once.
            auto elementCount = getIntVal(arrayType->getElementCount());
            for (IRIntegerValue i = 0; i < elementCount; i++)
            {
                auto elementAddr = builder->emitElementAddress(
                    elementPtrType,
                    concreteTypedVar,
                    builder->getIntValue(builder->getIntType(), i));
                emitMarshallingCode(builder, context, elementAddr);
            }
            break;
        }
    case kIROp_AnyValueType:
        {
            // A nested `AnyValue` is marshalled as the `uint` fields of its generated struct.
            auto anyValType = cast<IRAnyValueType>(dataType);
            auto info = ensureAnyValueType(anyValType);
            for (auto field : info->fieldKeys)
            {
                auto fieldAddr = builder->emitFieldAddress(
                    builder->getPtrType(builder->getUIntType()),
                    concreteTypedVar,
                    field);
                emitMarshallingCode(builder, context, fieldAddr);
            }
            break;
        }
    default:
        if (as<IRTextureTypeBase>(dataType) || as<IRSamplerStateTypeBase>(dataType))
        {
            context->marshalResourceHandle(builder, dataType, concreteTypedVar);
            return;
        }
        SLANG_UNIMPLEMENTED_X("Unimplemented type packing");
        break;
    }
}
struct TypePackingContext : TypeMarshallingContext
{
virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
{
switch (dataType->getOp())
{
case kIROp_IntType:
case kIROp_FloatType:
#if SLANG_PTR_IS_32
case kIROp_IntPtrType:
#endif
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
auto dstVal = builder->emitBitCast(builder->getUIntType(), srcVal);
auto dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
builder->emitStore(dstAddr, dstVal);
}
advanceOffset(4);
break;
}
case kIROp_BoolType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
IRInst* args[] = {srcVal, builder->getIntValue(builder->getUIntType(), 1), builder->getIntValue(builder->getUIntType(), 0) };
auto dstVal = builder->emitIntrinsicInst(builder->getUIntType(), kIROp_Select, 3, args);
auto dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
builder->emitStore(dstAddr, dstVal);
}
advanceOffset(4);
break;
}
case kIROp_UIntType:
#if SLANG_PTR_IS_32
case kIROp_UIntPtrType:
#endif
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
auto dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
builder->emitStore(dstAddr, srcVal);
}
advanceOffset(4);
break;
}
case kIROp_Int16Type:
case kIROp_UInt16Type:
case kIROp_HalfType:
{
ensureOffsetAt2ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
if (dataType->getOp() == kIROp_HalfType)
{
srcVal = builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal);
}
srcVal = builder->emitCast(builder->getType(kIROp_UIntType), srcVal);
auto dstAddr = builder->emitFieldAddress(
uintPtrType,
anyValueVar,
anyValInfo->fieldKeys[fieldOffset]);
auto dstVal = builder->emitLoad(dstAddr);
if (intraFieldOffset == 0)
{
dstVal = builder->emitBitAnd(
dstVal->getFullType(), dstVal,
builder->getIntValue(builder->getUIntType(), 0xFFFF0000));
}
else
{
srcVal = builder->emitShl(
srcVal->getFullType(), srcVal,
builder->getIntValue(builder->getUIntType(), 16));
dstVal = builder->emitBitAnd(
dstVal->getFullType(), dstVal,
builder->getIntValue(builder->getUIntType(), 0xFFFF));
}
dstVal = builder->emitBitOr(dstVal->getFullType(), dstVal, srcVal);
builder->emitStore(dstAddr, dstVal);
}
advanceOffset(2);
break;
}
case kIROp_Int8Type:
case kIROp_UInt8Type:
case kIROp_UInt64Type:
case kIROp_Int64Type:
case kIROp_DoubleType:
#if SLANG_PTR_IS_64
case kIROp_UIntPtrType:
case kIROp_IntPtrType:
#endif
SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
break;
default:
SLANG_UNREACHABLE("unknown basic type");
}
}
virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
{
SLANG_UNUSED(dataType);
ensureOffsetAt4ByteBoundary();
if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
{
auto srcVal = builder->emitLoad(concreteVar);
auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal);
auto lowBits = builder->emitCast(builder->getUIntType(), uint64Val);
auto shiftedBits = builder->emitShr(
builder->getUInt64Type(),
uint64Val,
builder->getIntValue(builder->getIntType(), 32));
auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits);
auto dstAddr1 = builder->emitFieldAddress(
uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
builder->emitStore(dstAddr1, lowBits);
auto dstAddr2 = builder->emitFieldAddress(
uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
builder->emitStore(dstAddr2, highBits);
advanceOffset(8);
}
}
};
// Generates the function that packs a value of `type` into the struct implementing
// `anyValueType`. The function has the shape `AnyValueN packAnyValueN(type)` and is
// emitted before `type` in the module.
IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType)
{
    IRBuilder builder(sharedContext->module);
    builder.setInsertBefore(type);
    auto anyValInfo = ensureAnyValueType(anyValueType);

    auto packFunc = builder.createFunc();
    StringBuilder funcName;
    funcName << "packAnyValue" << getIntVal(anyValueType->getSize());
    builder.addNameHintDecoration(packFunc, funcName.getUnownedSlice());

    // Currently we don't add linkage to the generated func, since we
    // do not have a way to compute mangled names from an IR entity.
    // This will lead to duplicate packing functions in linked code,
    // but there won't be correctness issues.
    packFunc->setFullType(builder.getFuncType(1, &type, anyValInfo->type));

    builder.setInsertInto(packFunc);
    builder.emitBlock();

    // Copy the parameter into a local variable so its elements can be addressed.
    auto valueParam = builder.emitParam(type);
    auto concreteTypedVar = builder.emitVar(type);
    builder.emitStore(concreteTypedVar, valueParam);

    auto resultVar = builder.emitVar(anyValInfo->type);

    // Initialize fields to 0 to prevent downstream compiler error.
    for (auto fieldKey : anyValInfo->fieldKeys)
    {
        auto fieldAddr = builder.emitFieldAddress(
            builder.getPtrType(builder.getUIntType()),
            resultVar,
            fieldKey);
        builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0));
    }

    // Walk the concrete type and write each element into the result struct.
    TypePackingContext packing;
    packing.anyValInfo = anyValInfo;
    packing.fieldOffset = 0;
    packing.intraFieldOffset = 0;
    packing.uintPtrType = builder.getPtrType(builder.getUIntType());
    packing.anyValueVar = resultVar;
    emitMarshallingCode(&builder, &packing, concreteTypedVar);

    builder.emitReturn(builder.emitLoad(resultVar));
    return packFunc;
}
// Marshalling context that copies leaf elements OUT OF the `uint` fields of the `AnyValue`
// struct held in `anyValueVar` and into a concrete value. The layout decisions
// (alignment, element sizes, supported types) mirror `TypePackingContext`.
struct TypeUnpackingContext : TypeMarshallingContext
{
    // Unpacks one basic-typed element into `concreteVar`.
    // 32-bit elements occupy a whole field; 16-bit elements share a field in pairs.
    // Elements beyond the last field are skipped, but the offset is still advanced.
    virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
    {
        switch (dataType->getOp())
        {
        case kIROp_IntType:
        case kIROp_FloatType:
#if SLANG_PTR_IS_32
        // Handled like `int`, matching the case list of the packing context.
        case kIROp_IntPtrType:
#endif
            {
                ensureOffsetAt4ByteBoundary();
                if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                {
                    auto srcAddr = builder->emitFieldAddress(
                        uintPtrType,
                        anyValueVar,
                        anyValInfo->fieldKeys[fieldOffset]);
                    auto srcVal = builder->emitLoad(srcAddr);
                    // Reinterpret the stored `uint` bits as the element type.
                    srcVal = builder->emitBitCast(dataType, srcVal);
                    builder->emitStore(concreteVar, srcVal);
                }
                advanceOffset(4);
                break;
            }
        case kIROp_BoolType:
            {
                ensureOffsetAt4ByteBoundary();
                if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                {
                    auto srcAddr = builder->emitFieldAddress(
                        uintPtrType,
                        anyValueVar,
                        anyValInfo->fieldKeys[fieldOffset]);
                    auto srcVal = builder->emitLoad(srcAddr);
                    // Any non-zero stored word reads back as `true`.
                    srcVal = builder->emitNeq(srcVal, builder->getIntValue(builder->getUIntType(), 0));
                    builder->emitStore(concreteVar, srcVal);
                }
                advanceOffset(4);
                break;
            }
        case kIROp_UIntType:
#if SLANG_PTR_IS_32
        // Handled like `uint`, matching the case list of the packing context.
        case kIROp_UIntPtrType:
#endif
            {
                ensureOffsetAt4ByteBoundary();
                if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                {
                    auto srcAddr = builder->emitFieldAddress(
                        uintPtrType,
                        anyValueVar,
                        anyValInfo->fieldKeys[fieldOffset]);
                    auto srcVal = builder->emitLoad(srcAddr);
                    builder->emitStore(concreteVar, srcVal);
                }
                advanceOffset(4);
                break;
            }
        case kIROp_Int16Type:
        case kIROp_UInt16Type:
        case kIROp_HalfType:
            {
                ensureOffsetAt2ByteBoundary();
                if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                {
                    auto srcAddr = builder->emitFieldAddress(
                        uintPtrType,
                        anyValueVar,
                        anyValInfo->fieldKeys[fieldOffset]);
                    auto srcVal = builder->emitLoad(srcAddr);
                    if (intraFieldOffset == 0)
                    {
                        // Element lives in the low 16 bits of the field.
                        srcVal = builder->emitBitAnd(
                            srcVal->getFullType(), srcVal,
                            builder->getIntValue(builder->getUIntType(), 0xFFFF));
                    }
                    else
                    {
                        // Element lives in the high 16 bits of the field.
                        srcVal = builder->emitShr(
                            srcVal->getFullType(), srcVal,
                            builder->getIntValue(builder->getUIntType(), 16));
                    }
                    // Narrow to a 16-bit integer; `half` goes through `uint16` and is then
                    // bit-cast to the element type.
                    if (dataType->getOp() == kIROp_Int16Type)
                    {
                        srcVal = builder->emitCast(builder->getType(kIROp_Int16Type), srcVal);
                    }
                    else
                    {
                        srcVal = builder->emitCast(builder->getType(kIROp_UInt16Type), srcVal);
                    }
                    if (dataType->getOp() == kIROp_HalfType)
                    {
                        srcVal = builder->emitBitCast(dataType, srcVal);
                    }
                    builder->emitStore(concreteVar, srcVal);
                }
                advanceOffset(2);
                break;
            }
        case kIROp_UInt64Type:
        case kIROp_Int64Type:
        case kIROp_DoubleType:
        case kIROp_Int8Type:
        case kIROp_UInt8Type:
#if SLANG_PTR_IS_64
        // Pointer-sized integers are 64-bit here and are rejected the same way the
        // packing context rejects them, instead of falling into the unreachable default.
        case kIROp_UIntPtrType:
        case kIROp_IntPtrType:
#endif
            SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
            break;
        default:
            SLANG_UNREACHABLE("unknown basic type");
        }
    }

    // Unpacks a resource handle from two consecutive fields (low 32 bits first) by
    // combining them into a 64-bit value and bit-casting it to `dataType`. Nothing is
    // emitted, and the offset is not advanced, when two fields are not available.
    virtual void marshalResourceHandle(
        IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
    {
        ensureOffsetAt4ByteBoundary();
        if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
        {
            auto srcAddr = builder->emitFieldAddress(
                uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
            auto lowBits = builder->emitLoad(srcAddr);
            auto srcAddr1 = builder->emitFieldAddress(
                uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
            auto highBits = builder->emitLoad(srcAddr1);
            auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
            combinedBits = builder->emitBitCast(dataType, combinedBits);
            builder->emitStore(concreteVar, combinedBits);
            advanceOffset(8);
        }
    }
};
// Generates the function that unpacks a value of `type` from the struct implementing
// `anyValueType`. The function has the shape `type unpackAnyValueN(AnyValueN)` and is
// emitted before `type` in the module.
IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType)
{
    IRBuilder builder(sharedContext->module);
    builder.setInsertBefore(type);
    auto anyValInfo = ensureAnyValueType(anyValueType);

    auto unpackFunc = builder.createFunc();
    StringBuilder funcName;
    funcName << "unpackAnyValue" << getIntVal(anyValueType->getSize());
    builder.addNameHintDecoration(unpackFunc, funcName.getUnownedSlice());

    unpackFunc->setFullType(builder.getFuncType(1, &anyValInfo->type, type));

    builder.setInsertInto(unpackFunc);
    builder.emitBlock();

    // Copy the parameter into a local variable so its fields can be addressed.
    auto anyValueParam = builder.emitParam(anyValInfo->type);
    auto anyValueVar = builder.emitVar(anyValInfo->type);
    builder.emitStore(anyValueVar, anyValueParam);

    auto resultVar = builder.emitVar(type);

    // Walk the concrete type and fill each element from the `AnyValue` fields.
    TypeUnpackingContext unpacking;
    unpacking.anyValInfo = anyValInfo;
    unpacking.fieldOffset = 0;
    unpacking.intraFieldOffset = 0;
    unpacking.uintPtrType = builder.getPtrType(builder.getUIntType());
    unpacking.anyValueVar = anyValueVar;
    emitMarshallingCode(&builder, &unpacking, resultVar);

    builder.emitReturn(builder.emitLoad(resultVar));
    return unpackFunc;
}
// Returns the pack/unpack function pair for marshalling `type` to and from `anyValueType`.
// The pair is generated on first request and cached in `mapTypeMarshalingFunctions`
// under the (type, any-value size) key.
MarshallingFunctionSet ensureMarshallingFunc(IRType* type, IRAnyValueType* anyValueType)
{
    MarshallingFunctionKey lookupKey;
    lookupKey.originalType = type;
    lookupKey.anyValueSize = getIntVal(anyValueType->getSize());

    if (auto existing = mapTypeMarshalingFunctions.tryGetValue(lookupKey))
    {
        return *existing;
    }

    MarshallingFunctionSet generated;
    generated.packFunc = generatePackingFunc(type, anyValueType);
    generated.unpackFunc = generateUnpackingFunc(type, anyValueType);
    mapTypeMarshalingFunctions[lookupKey] = generated;
    return generated;
}
// Replaces `packInst` with a call to the packing function generated for its operand type
// and result `AnyValue` type, then removes the original instruction.
void processPackInst(IRPackAnyValue* packInst)
{
    auto valueToPack = packInst->getValue();
    auto marshallingFuncs = ensureMarshallingFunc(
        valueToPack->getDataType(),
        cast<IRAnyValueType>(packInst->getDataType()));

    IRBuilder builder(sharedContext->module);
    builder.setInsertBefore(packInst);
    auto packCall = builder.emitCallInst(
        packInst->getDataType(),
        marshallingFuncs.packFunc,
        1,
        &valueToPack);

    packInst->replaceUsesWith(packCall);
    packInst->removeAndDeallocate();
}
// Replaces `unpackInst` with a call to the unpacking function generated for its result
// type and operand `AnyValue` type, then removes the original instruction.
void processUnpackInst(IRUnpackAnyValue* unpackInst)
{
    auto packedValue = unpackInst->getValue();
    auto marshallingFuncs = ensureMarshallingFunc(
        unpackInst->getDataType(),
        cast<IRAnyValueType>(packedValue->getDataType()));

    IRBuilder builder(sharedContext->module);
    builder.setInsertBefore(unpackInst);
    auto unpackCall = builder.emitCallInst(
        unpackInst->getDataType(),
        marshallingFuncs.unpackFunc,
        1,
        &packedValue);

    unpackInst->replaceUsesWith(unpackCall);
    unpackInst->removeAndDeallocate();
}
// Redirects every use of the `AnyValue` type `type` to the struct type that implements it,
// generating that struct first if it does not exist yet.
void processAnyValueType(IRAnyValueType* type)
{
    AnyValueTypeInfo* implementation = ensureAnyValueType(type);
    type->replaceUsesWith(implementation->type);
}
// Dispatches a single instruction: pack and unpack instructions are lowered to calls,
// everything else is left untouched.
void processInst(IRInst* inst)
{
    if (auto packInst = as<IRPackAnyValue>(inst))
    {
        processPackInst(packInst);
        return;
    }
    if (auto unpackInst = as<IRUnpackAnyValue>(inst))
    {
        processUnpackInst(unpackInst);
    }
}
void processModule()
{
// We start by initializing our shared IR building state,
// since we will re-use that state for any code we
// generate along the way.
//
sharedContext->addToWorkList(sharedContext->module->getModuleInst());
while (sharedContext->workList.getCount() != 0)
{
IRInst* inst = sharedContext->workList.getLast();
sharedContext->workList.removeLast();
sharedContext->workListSet.remove(inst);
processInst(inst);
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
{
sharedContext->addToWorkList(child);
}
}
// Finally, replace all `AnyValueType` with the actual struct type that implements it.
for (auto inst : sharedContext->module->getModuleInst()->getChildren())
{
if (auto anyValueType = as<IRAnyValueType>(inst))
processAnyValueType(anyValueType);
}
sharedContext->mapInterfaceRequirementKeyValue.clear();
}
};
// Entry point of the sub-pass: generates `AnyValue` packing/unpacking functions for the
// module referenced by `sharedContext` and rewrites pack/unpack instructions to call them.
void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext)
{
    AnyValueMarshallingContext marshallingPass;
    marshallingPass.sharedContext = sharedContext;
    marshallingPass.processModule();
}
// Rounds `x` up to the next multiple of `alignment` using integer division.
SlangInt alignUp(SlangInt x, SlangInt alignment)
{
    const SlangInt blockCount = (x + alignment - 1) / alignment;
    return blockCount * alignment;
}
// Computes the offset reached after laying out a value of `type` starting at byte
// `offset`, using the same alignment rules as the marshalling code (4-byte elements are
// aligned to 4, 2-byte elements to 2). The result is not padded at the end.
//
// Returns a negative value when the size cannot be determined or the type is not
// supported: 64-bit and 8-bit scalars, unresolved witness lookups, and any type that is
// not listed below.
SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
{
    switch (type->getOp())
    {
    case kIROp_IntType:
    case kIROp_FloatType:
    case kIROp_UIntType:
    case kIROp_BoolType:
        return alignUp(offset, 4) + 4;
    case kIROp_UInt64Type:
    case kIROp_Int64Type:
    case kIROp_DoubleType:
        return -1;
    case kIROp_Int16Type:
    case kIROp_UInt16Type:
    case kIROp_HalfType:
        return alignUp(offset, 2) + 2;
    case kIROp_UInt8Type:
    case kIROp_Int8Type:
        return -1;
    case kIROp_VectorType:
        {
            auto vectorType = static_cast<IRVectorType*>(type);
            auto elementType = vectorType->getElementType();
            auto elementCount = getIntVal(vectorType->getElementCount());
            for (IRIntegerValue i = 0; i < elementCount; i++)
            {
                offset = _getAnyValueSizeRaw(elementType, offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
    case kIROp_MatrixType:
        {
            auto matrixType = static_cast<IRMatrixType*>(type);
            auto elementType = matrixType->getElementType();
            auto colCount = getIntVal(matrixType->getColumnCount());
            auto rowCount = getIntVal(matrixType->getRowCount());
            for (IRIntegerValue i = 0; i < colCount; i++)
            {
                for (IRIntegerValue j = 0; j < rowCount; j++)
                {
                    offset = _getAnyValueSizeRaw(elementType, offset);
                    if (offset < 0) return offset;
                }
            }
            return offset;
        }
    case kIROp_StructType:
        {
            auto structType = cast<IRStructType>(type);
            for (auto field : structType->getFields())
            {
                offset = _getAnyValueSizeRaw(field->getFieldType(), offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
    case kIROp_ArrayType:
        {
            auto arrayType = cast<IRArrayType>(type);
            // The element count and type do not change while iterating; read them once.
            auto elementType = arrayType->getElementType();
            auto elementCount = getIntVal(arrayType->getElementCount());
            for (IRIntegerValue i = 0; i < elementCount; i++)
            {
                offset = _getAnyValueSizeRaw(elementType, offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
    case kIROp_AnyValueType:
        {
            auto anyValueType = cast<IRAnyValueType>(type);
            return alignUp(offset, 4) + (SlangInt)getIntVal(anyValueType->getSize());
        }
    case kIROp_TupleType:
        {
            auto tupleType = cast<IRTupleType>(type);
            for (UInt i = 0; i < tupleType->getOperandCount(); i++)
            {
                auto elementType = tupleType->getOperand(i);
                offset = _getAnyValueSizeRaw((IRType*)elementType, offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
    case kIROp_WitnessTableType:
    case kIROp_WitnessTableIDType:
    case kIROp_RTTIHandleType:
        {
            return alignUp(offset, 4) + kRTTIHandleSize;
        }
    case kIROp_InterfaceType:
        {
            // An interface-typed value takes the interface's any-value size plus the
            // RTTI header, rounded up to 4 bytes.
            auto interfaceType = cast<IRInterfaceType>(type);
            auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
            size += kRTTIHeaderSize;
            return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
        }
    case kIROp_AssociatedType:
        {
            // Take the largest result over all operands of the associated type.
            auto associatedType = cast<IRAssociatedType>(type);
            SlangInt maxSize = 0;
            for (UInt i = 0; i < associatedType->getOperandCount(); i++)
                maxSize = Math::Max(maxSize, _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset));
            return maxSize;
        }
    case kIROp_ThisType:
        {
            auto thisType = cast<IRThisType>(type);
            auto interfaceType = thisType->getConstraintType();
            auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
            return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
        }
    case kIROp_ExtractExistentialType:
        {
            auto existentialValue = type->getOperand(0);
            auto interfaceType = cast<IRInterfaceType>(existentialValue->getDataType());
            auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
            return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
        }
    case kIROp_LookupWitness:
        {
            auto witnessTableVal = type->getOperand(0);
            auto key = type->getOperand(1);
            IRType* assocType = nullptr;
            if (auto witnessTableType = as<IRWitnessTableTypeBase>(witnessTableVal->getDataType()))
            {
                // `as<>` yields null when the conformance type is not an interface type;
                // in that case no associated type can be found and the lookup fails below.
                if (auto interfaceType = as<IRInterfaceType>(witnessTableType->getConformanceType()))
                {
                    // Walk through interface operands to find a match, the result should be an
                    // associated type entry.
                    //
                    for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++)
                    {
                        auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii));
                        if (entry->getRequirementKey() == key &&
                            as<IRAssociatedType>(entry->getRequirementVal()))
                        {
                            assocType = (IRType*)entry->getRequirementVal();
                            break;
                        }
                    }
                }
            }
            if (!assocType)
                return -1;

            // Use the smallest any-value size among the associated type's operands.
            IRIntegerValue anyValueSize = kInvalidAnyValueSize;
            for (UInt i = 0; i < assocType->getOperandCount(); i++)
            {
                anyValueSize = Math::Min(
                    anyValueSize,
                    SharedGenericsLoweringContext::getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc));
            }
            if (anyValueSize == kInvalidAnyValueSize)
                return -1;
            return alignUp(offset, 4) + alignUp((SlangInt)anyValueSize, 4);
        }
    default:
        // Textures and sampler states are marshalled as 8-byte handles.
        if (as<IRTextureTypeBase>(type) || as<IRSamplerStateTypeBase>(type))
        {
            return alignUp(offset, 4) + 8;
        }
        return -1;
    }
}
// Returns the number of bytes needed to store `type` in an `AnyValue`, rounded up to a
// multiple of 4. A negative result from `_getAnyValueSizeRaw` is passed through unchanged.
SlangInt getAnyValueSize(IRType* type)
{
    const SlangInt unpaddedSize = _getAnyValueSizeRaw(type, 0);
    return unpaddedSize < 0 ? unpaddedSize : alignUp(unpaddedSize, 4);
}
}