https://github.com/shader-slang/slang
Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-any-value-marshalling.cpp
#include "slang-ir-any-value-marshalling.h"

#include "../core/slang-math.h"
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"

namespace Slang
{
    // This is a subpass of generics lowering IR transformation.
    // This pass generates packing/unpacking functions for `AnyValue`s,
    // and replaces all `IRPackAnyValue` and `IRUnpackAnyValue` with calls to these
    // functions.
    struct AnyValueMarshallingContext
    {
        SharedGenericsLoweringContext* sharedContext;

        // Stores information about generated `AnyValue` struct types.
        struct AnyValueTypeInfo : RefObject
        {
            IRType* type; // The generated IR value for the `AnyValue<N>` struct type.
            List<IRStructKey*> fieldKeys; // `IRStructKey`s for the fields of the generated type.
        };

        Dictionary<IRIntegerValue, RefPtr<AnyValueTypeInfo>> generatedAnyValueTypes;

        struct MarshallingFunctionKey
        {
            IRType* originalType;
            IRIntegerValue anyValueSize;
            bool operator ==(MarshallingFunctionKey other) const
            {
                return originalType == other.originalType && anyValueSize == other.anyValueSize;
            }
            HashCode getHashCode() const
            {
                return combineHash(Slang::getHashCode(originalType), Slang::getHashCode(anyValueSize));
            }
        };

        struct MarshallingFunctionSet
        {
            IRFunc* packFunc;
            IRFunc* unpackFunc;
        };

        // Stores the generated packing/unpacking functions for lookup.
        Dictionary<MarshallingFunctionKey, MarshallingFunctionSet> mapTypeMarshalingFunctions;

        AnyValueTypeInfo* ensureAnyValueType(IRAnyValueType* type)
        {
            auto size = getIntVal(type->getSize());
            if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size))
                return typeInfo->Ptr();
            RefPtr<AnyValueTypeInfo> info = new AnyValueTypeInfo();
            IRBuilder builder(sharedContext->module);
            builder.setInsertBefore(type);
            auto structType = builder.createStructType();
            info->type = structType;
            StringBuilder nameSb;
            nameSb << "AnyValue" << size;
            builder.addExportDecoration(structType, nameSb.getUnownedSlice());
            auto fieldCount = (size + sizeof(uint32_t) - 1) / sizeof(uint32_t);
            for (decltype(fieldCount) i = 0; i < fieldCount; i++)
            {
                auto key = builder.createStructKey();
                nameSb.clear();
                nameSb << "field" << i;
                builder.addNameHintDecoration(key, nameSb.getUnownedSlice());
                nameSb << "_anyVal" << size;
                builder.addExportDecoration(key, nameSb.getUnownedSlice());
                builder.createStructField(structType, key, builder.getUIntType());
                info->fieldKeys.add(key);
            }
            generatedAnyValueTypes[size] = info;
            return info.Ptr();
        }

        struct TypeMarshallingContext
        {
            AnyValueTypeInfo* anyValInfo;
            uint32_t fieldOffset;
            uint32_t intraFieldOffset;
            IRType* uintPtrType;
            IRInst* anyValueVar;
            // Defines what to do with basic typed data elements.
            virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;
            // Defines what to do with resource handle elements.
            virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0;

            void ensureOffsetAt4ByteBoundary()
            {
                if (intraFieldOffset)
                {
                    fieldOffset++;
                    intraFieldOffset = 0;
                }
            }
            void ensureOffsetAt2ByteBoundary()
            {
                if (intraFieldOffset == 0)
                    return;
                if (intraFieldOffset <= 2)
                {
                    intraFieldOffset = 2;
                    return;
                }
                fieldOffset++;
                intraFieldOffset = 0;
                return;
            }
            void advanceOffset(uint32_t bytes)
            {
                intraFieldOffset += bytes;
                fieldOffset += intraFieldOffset / 4;
                intraFieldOffset = intraFieldOffset % 4;
            }
        };

        void emitMarshallingCode(
            IRBuilder* builder,
            TypeMarshallingContext* context,
            IRInst* concreteTypedVar)
        {
            auto dataType = cast<IRPtrTypeBase>(concreteTypedVar->getDataType())->getValueType();
            switch (dataType->getOp())
            {
            case kIROp_IntType:
            case kIROp_FloatType:
            case kIROp_UIntType:
            case kIROp_UInt64Type:
            case kIROp_Int64Type:
            case kIROp_DoubleType:
            case kIROp_Int8Type:
            case kIROp_Int16Type:
            case kIROp_UInt8Type:
            case kIROp_UInt16Type:
            case kIROp_HalfType:
            case kIROp_BoolType:
            case kIROp_IntPtrType:
            case kIROp_UIntPtrType:
                context->marshalBasicType(builder, dataType, concreteTypedVar);
                break;
            case kIROp_VectorType:
            {
                auto vectorType = static_cast<IRVectorType*>(dataType);
                auto elementType = vectorType->getElementType();
                auto elementCount = getIntVal(vectorType->getElementCount());
                auto elementPtrType = builder->getPtrType(elementType);
                for (IRIntegerValue i = 0; i < elementCount; i++)
                {
                    auto elementAddr = builder->emitElementAddress(
                        elementPtrType,
                        concreteTypedVar,
                        builder->getIntValue(builder->getIntType(), i));
                    emitMarshallingCode(builder, context, elementAddr);
                }
                break;
            }
            case kIROp_MatrixType:
            {
                auto matrixType = static_cast<IRMatrixType*>(dataType);
                auto elementType = matrixType->getElementType();
                auto colCount = getIntVal(matrixType->getColumnCount());
                auto rowCount = getIntVal(matrixType->getRowCount());
                auto rowVecType = builder->getVectorType(elementType, matrixType->getRowCount());
                for (IRIntegerValue i = 0; i < colCount; i++)
                {
                    auto col = builder->emitElementAddress(
                        builder->getPtrType(rowVecType),
                        concreteTypedVar,
                        builder->getIntValue(builder->getIntType(), i));
                    for (IRIntegerValue j = 0; j < rowCount; j++)
                    {
                        auto element = builder->emitElementAddress(
                            builder->getPtrType(elementType),
                            col,
                            builder->getIntValue(builder->getIntType(), j));
                        emitMarshallingCode(builder, context, element);
                    }
                }
                break;
            }
            case kIROp_StructType:
            {
                auto structType = cast<IRStructType>(dataType);
                for (auto field : structType->getFields())
                {
                    auto fieldAddr = builder->emitFieldAddress(
                        builder->getPtrType(field->getFieldType()),
                        concreteTypedVar,
                        field->getKey());
                    emitMarshallingCode(builder, context, fieldAddr);
                }
                break;
            }
            case kIROp_ArrayType:
            {
                auto arrayType = cast<IRArrayType>(dataType);
                auto elementPtrType = builder->getPtrType(arrayType->getElementType());
                for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++)
                {
                    auto fieldAddr = builder->emitElementAddress(
                        elementPtrType,
                        concreteTypedVar,
                        builder->getIntValue(builder->getIntType(), i));
                    emitMarshallingCode(builder, context, fieldAddr);
                }
                break;
            }
            case kIROp_AnyValueType:
            {
                auto anyValType = cast<IRAnyValueType>(dataType);
                auto info = ensureAnyValueType(anyValType);
                for (auto field : info->fieldKeys)
                {
                    auto fieldAddr = builder->emitFieldAddress(
                        builder->getPtrType(builder->getUIntType()),
                        concreteTypedVar,
                        field);
                    emitMarshallingCode(builder, context, fieldAddr);
                }
                break;
            }
            default:
                if (as<IRTextureTypeBase>(dataType) || as<IRSamplerStateTypeBase>(dataType))
                {
                    context->marshalResourceHandle(builder, dataType, concreteTypedVar);
                    return;
                }
                SLANG_UNIMPLEMENTED_X("Unimplemented type packing");
                break;
            }
        }

        struct TypePackingContext : TypeMarshallingContext
        {
            virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
            {
                switch (dataType->getOp())
                {
                case kIROp_IntType:
                case kIROp_FloatType:
#if SLANG_PTR_IS_32
                case kIROp_IntPtrType:
#endif
                {
                    ensureOffsetAt4ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcVal = builder->emitLoad(concreteVar);
                        auto dstVal = builder->emitBitCast(builder->getUIntType(), srcVal);
                        auto dstAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        builder->emitStore(dstAddr, dstVal);
                    }
                    advanceOffset(4);
                    break;
                }
                case kIROp_BoolType:
                {
                    ensureOffsetAt4ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcVal = builder->emitLoad(concreteVar);
                        IRInst* args[] = {srcVal, builder->getIntValue(builder->getUIntType(), 1), builder->getIntValue(builder->getUIntType(), 0) };
                        auto dstVal = builder->emitIntrinsicInst(builder->getUIntType(), kIROp_Select, 3, args);
                        auto dstAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        builder->emitStore(dstAddr, dstVal);
                    }
                    advanceOffset(4);
                    break;
                }
                case kIROp_UIntType:
#if SLANG_PTR_IS_32
                case kIROp_UIntPtrType:
#endif
                {
                    ensureOffsetAt4ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcVal = builder->emitLoad(concreteVar);
                        auto dstAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        builder->emitStore(dstAddr, srcVal);
                    }
                    advanceOffset(4);
                    break;
                }
                case kIROp_Int16Type:
                case kIROp_UInt16Type:
                case kIROp_HalfType:
                {
                    ensureOffsetAt2ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcVal = builder->emitLoad(concreteVar);
                        if (dataType->getOp() == kIROp_HalfType)
                        {
                            srcVal = builder->emitBitCast(builder->getType(kIROp_UInt16Type), srcVal);
                        }
                        srcVal = builder->emitCast(builder->getType(kIROp_UIntType), srcVal);
                        auto dstAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        auto dstVal = builder->emitLoad(dstAddr);
                        if (intraFieldOffset == 0)
                        {
                            dstVal = builder->emitBitAnd(
                                dstVal->getFullType(), dstVal,
                                builder->getIntValue(builder->getUIntType(), 0xFFFF0000));
                        }
                        else
                        {
                            srcVal = builder->emitShl(
                                srcVal->getFullType(), srcVal,
                                builder->getIntValue(builder->getUIntType(), 16));
                            dstVal = builder->emitBitAnd(
                                dstVal->getFullType(), dstVal,
                                builder->getIntValue(builder->getUIntType(), 0xFFFF));
                        }
                        dstVal = builder->emitBitOr(dstVal->getFullType(), dstVal, srcVal);
                        builder->emitStore(dstAddr, dstVal);
                    }
                    advanceOffset(2);
                    break;
                }
                case kIROp_Int8Type:
                case kIROp_UInt8Type:
                case kIROp_UInt64Type:
                case kIROp_Int64Type:
                case kIROp_DoubleType:
#if SLANG_PTR_IS_64
                case kIROp_UIntPtrType:
                case kIROp_IntPtrType:
#endif
                    SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
                    break;
                default:
                    SLANG_UNREACHABLE("unknown basic type");
                }
            }

            virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
            {
                SLANG_UNUSED(dataType);
                ensureOffsetAt4ByteBoundary();
                if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                {
                    auto srcVal = builder->emitLoad(concreteVar);
                    auto uint64Val = builder->emitBitCast(builder->getUInt64Type(), srcVal);
                    auto lowBits = builder->emitCast(builder->getUIntType(), uint64Val);
                    auto shiftedBits = builder->emitShr(
                        builder->getUInt64Type(),
                        uint64Val,
                        builder->getIntValue(builder->getIntType(), 32));
                    auto highBits = builder->emitBitCast(builder->getUIntType(), shiftedBits);
                    auto dstAddr1 = builder->emitFieldAddress(
                        uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
                    builder->emitStore(dstAddr1, lowBits);
                    auto dstAddr2 = builder->emitFieldAddress(
                        uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
                    builder->emitStore(dstAddr2, highBits);
                    advanceOffset(8);
                }
            }
        };

        IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType)
        {
            IRBuilder builder(sharedContext->module);
            builder.setInsertBefore(type);
            auto anyValInfo = ensureAnyValueType(anyValueType);

            auto func = builder.createFunc();

            StringBuilder nameSb;
            nameSb << "packAnyValue" << getIntVal(anyValueType->getSize());
            builder.addNameHintDecoration(func, nameSb.getUnownedSlice());
            // Currently we don't add linkage to the generated func, since we
            // do not have a way to compute mangled names from an IR entity.
            // This will leads to duplicate packing functions in linked code
            // but there won't be correctness issues.

            auto funcType = builder.getFuncType(1, &type, anyValInfo->type);
            func->setFullType(funcType);
            builder.setInsertInto(func);

            builder.emitBlock();

            auto param = builder.emitParam(type);
            auto concreteTypedVar = builder.emitVar(type);
            builder.emitStore(concreteTypedVar, param);
            auto resultVar = builder.emitVar(anyValInfo->type);

            // Initialize fields to 0 to prevent downstream compiler error.
            for (uint32_t offset = 0; offset < (uint32_t)anyValInfo->fieldKeys.getCount(); offset++)
            {
                auto fieldAddr = builder.emitFieldAddress(
                    builder.getPtrType(builder.getUIntType()),
                    resultVar,
                    anyValInfo->fieldKeys[offset]
                );
                builder.emitStore(fieldAddr, builder.getIntValue(builder.getUIntType(), 0));
            }

            TypePackingContext context;
            context.anyValInfo = anyValInfo;
            context.fieldOffset = context.intraFieldOffset = 0;
            context.uintPtrType = builder.getPtrType(builder.getUIntType());
            context.anyValueVar = resultVar;
            emitMarshallingCode(&builder, &context, concreteTypedVar);

            auto load = builder.emitLoad(resultVar);
            builder.emitReturn(load);
            return func;
        }

        struct TypeUnpackingContext : TypeMarshallingContext
        {
            virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
            {
                switch (dataType->getOp())
                {
                case kIROp_IntType:
                case kIROp_FloatType:
                {
                    ensureOffsetAt4ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        auto srcVal = builder->emitLoad(srcAddr);
                        srcVal = builder->emitBitCast(dataType, srcVal);
                        builder->emitStore(concreteVar, srcVal);
                    }
                    advanceOffset(4);
                    break;
                }
                case kIROp_BoolType:
                {
                    ensureOffsetAt4ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        auto srcVal = builder->emitLoad(srcAddr);
                        srcVal = builder->emitNeq(srcVal, builder->getIntValue(builder->getUIntType(), 0));
                        builder->emitStore(concreteVar, srcVal);
                    }
                    advanceOffset(4);
                    break;
                }
                case kIROp_UIntType:
                {
                    ensureOffsetAt4ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        auto srcVal = builder->emitLoad(srcAddr);
                        builder->emitStore(concreteVar, srcVal);
                    }
                    advanceOffset(4);
                    break;
                }
                case kIROp_Int16Type:
                case kIROp_UInt16Type:
                case kIROp_HalfType:
                {
                    ensureOffsetAt2ByteBoundary();
                    if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                    {
                        auto srcAddr = builder->emitFieldAddress(
                            uintPtrType,
                            anyValueVar,
                            anyValInfo->fieldKeys[fieldOffset]);
                        auto srcVal = builder->emitLoad(srcAddr);
                        if (intraFieldOffset == 0)
                        {
                            srcVal = builder->emitBitAnd(
                                srcVal->getFullType(), srcVal,
                                builder->getIntValue(builder->getUIntType(), 0xFFFF));
                        }
                        else
                        {
                            srcVal = builder->emitShr(
                                srcVal->getFullType(), srcVal,
                                builder->getIntValue(builder->getUIntType(), 16));
                        }
                        if (dataType->getOp() == kIROp_Int16Type)
                        {
                            srcVal = builder->emitCast(builder->getType(kIROp_Int16Type), srcVal);
                        }
                        else
                        {
                            srcVal = builder->emitCast(builder->getType(kIROp_UInt16Type), srcVal);
                        }
                        if (dataType->getOp() == kIROp_HalfType)
                        {
                            srcVal = builder->emitBitCast(dataType, srcVal);
                        }
                        builder->emitStore(concreteVar, srcVal);
                    }
                    advanceOffset(2);
                    break;
                }
                case kIROp_UInt64Type:
                case kIROp_Int64Type:
                case kIROp_DoubleType:
                case kIROp_Int8Type:
                case kIROp_UInt8Type:
                    SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
                    break;
                default:
                    SLANG_UNREACHABLE("unknown basic type");
                }
            }

            virtual void marshalResourceHandle(
                IRBuilder* builder, IRType* dataType, IRInst* concreteVar) override
            {
                ensureOffsetAt4ByteBoundary();
                if (fieldOffset + 1 < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
                {
                    auto srcAddr = builder->emitFieldAddress(
                        uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset]);
                    auto lowBits = builder->emitLoad(srcAddr);
                    
                    auto srcAddr1 = builder->emitFieldAddress(
                        uintPtrType, anyValueVar, anyValInfo->fieldKeys[fieldOffset + 1]);
                    auto highBits = builder->emitLoad(srcAddr1);

                    auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
                    combinedBits = builder->emitBitCast(dataType, combinedBits);
                    builder->emitStore(concreteVar, combinedBits);
                    advanceOffset(8);
                }
            }
        };

        IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType)
        {
            IRBuilder builder(sharedContext->module);
            builder.setInsertBefore(type);
            auto anyValInfo = ensureAnyValueType(anyValueType);

            auto func = builder.createFunc();

            StringBuilder nameSb;
            nameSb << "unpackAnyValue" << getIntVal(anyValueType->getSize());
            builder.addNameHintDecoration(func, nameSb.getUnownedSlice());

            auto funcType = builder.getFuncType(1, &anyValInfo->type, type);
            func->setFullType(funcType);
            builder.setInsertInto(func);

            builder.emitBlock();

            auto param = builder.emitParam(anyValInfo->type);
            auto anyValueVar = builder.emitVar(anyValInfo->type);
            builder.emitStore(anyValueVar, param);
            auto resultVar = builder.emitVar(type);

            TypeUnpackingContext context;
            context.anyValInfo = anyValInfo;
            context.fieldOffset = context.intraFieldOffset = 0;
            context.uintPtrType = builder.getPtrType(builder.getUIntType());
            context.anyValueVar = anyValueVar;
            emitMarshallingCode(&builder, &context, resultVar);
            auto load = builder.emitLoad(resultVar);
            builder.emitReturn(load);
            return func;
        }

        // Ensures the marshalling functions between `type` and `anyValueType` are already generated.
        // Returns the generated marshalling functions.
        MarshallingFunctionSet ensureMarshallingFunc(IRType* type, IRAnyValueType* anyValueType)
        {
            auto size = getIntVal(anyValueType->getSize());
            MarshallingFunctionKey key;
            key.originalType = type;
            key.anyValueSize = size;
            MarshallingFunctionSet funcSet;
            if (mapTypeMarshalingFunctions.tryGetValue(key, funcSet))
                return funcSet;
            funcSet.packFunc = generatePackingFunc(type, anyValueType);
            funcSet.unpackFunc = generateUnpackingFunc(type, anyValueType);
            mapTypeMarshalingFunctions[key] = funcSet;
            return funcSet;
        }

        void processPackInst(IRPackAnyValue* packInst)
        {
            auto operand = packInst->getValue();
            auto func = ensureMarshallingFunc(
                operand->getDataType(),
                cast<IRAnyValueType>(packInst->getDataType()));
            IRBuilder builderStorage(sharedContext->module);
            auto builder = &builderStorage;
            builder->setInsertBefore(packInst);
            auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand);
            packInst->replaceUsesWith(callInst);
            packInst->removeAndDeallocate();
        }

        void processUnpackInst(IRUnpackAnyValue* unpackInst)
        {
            auto operand = unpackInst->getValue();
            auto func = ensureMarshallingFunc(
                unpackInst->getDataType(),
                cast<IRAnyValueType>(operand->getDataType()));
            IRBuilder builderStorage(sharedContext->module);
            auto builder = &builderStorage;
            builder->setInsertBefore(unpackInst);
            auto callInst = builder->emitCallInst(unpackInst->getDataType(), func.unpackFunc, 1, &operand);
            unpackInst->replaceUsesWith(callInst);
            unpackInst->removeAndDeallocate();
        }

        void processAnyValueType(IRAnyValueType* type)
        {
            auto info = ensureAnyValueType(type);
            type->replaceUsesWith(info->type);
        }

        void processInst(IRInst* inst)
        {
            if (auto packInst = as<IRPackAnyValue>(inst))
            {
                processPackInst(packInst);
            }
            else if (auto unpackInst = as<IRUnpackAnyValue>(inst))
            {
                processUnpackInst(unpackInst);
            }
        }

        void processModule()
        {
            // We start by initializing our shared IR building state,
            // since we will re-use that state for any code we
            // generate along the way.
            //
            sharedContext->addToWorkList(sharedContext->module->getModuleInst());

            while (sharedContext->workList.getCount() != 0)
            {
                IRInst* inst = sharedContext->workList.getLast();

                sharedContext->workList.removeLast();
                sharedContext->workListSet.remove(inst);

                processInst(inst);

                for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
                {
                    sharedContext->addToWorkList(child);
                }
            }

            // Finally, replace all `AnyValueType` with the actual struct type that implements it.
            for (auto inst : sharedContext->module->getModuleInst()->getChildren())
            {
                if (auto anyValueType = as<IRAnyValueType>(inst))
                    processAnyValueType(anyValueType);
            }
            sharedContext->mapInterfaceRequirementKeyValue.clear();
        }
    };

    void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext)
    {
        AnyValueMarshallingContext context;
        context.sharedContext = sharedContext;
        context.processModule();
    }

    SlangInt alignUp(SlangInt x, SlangInt alignment)
    {
        return (x + alignment - 1) / alignment * alignment;
    }

    SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
    {
        switch (type->getOp())
        {
        case kIROp_IntType:
        case kIROp_FloatType:
        case kIROp_UIntType:
        case kIROp_BoolType:
            return alignUp(offset, 4) + 4;
        case kIROp_UInt64Type:
        case kIROp_Int64Type:
        case kIROp_DoubleType:
            return -1;
        case kIROp_Int16Type:
        case kIROp_UInt16Type:
        case kIROp_HalfType:
            return alignUp(offset, 2) + 2;
        case kIROp_UInt8Type:
        case kIROp_Int8Type:
            return -1;
        case kIROp_VectorType:
        {
            auto vectorType = static_cast<IRVectorType*>(type);
            auto elementType = vectorType->getElementType();
            auto elementCount = getIntVal(vectorType->getElementCount());
            for (IRIntegerValue i = 0; i < elementCount; i++)
            {
                offset = _getAnyValueSizeRaw(elementType, offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
        case kIROp_MatrixType:
        {
            auto matrixType = static_cast<IRMatrixType*>(type);
            auto elementType = matrixType->getElementType();
            auto colCount = getIntVal(matrixType->getColumnCount());
            auto rowCount = getIntVal(matrixType->getRowCount());
            for (IRIntegerValue i = 0; i < colCount; i++)
            {
                for (IRIntegerValue j = 0; j < rowCount; j++)
                {
                    offset = _getAnyValueSizeRaw(elementType, offset);
                    if (offset < 0) return offset;
                }
            }
            return offset;
        }
        case kIROp_StructType:
        {
            auto structType = cast<IRStructType>(type);
            for (auto field : structType->getFields())
            {
                offset = _getAnyValueSizeRaw(field->getFieldType(), offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
        case kIROp_ArrayType:
        {
            auto arrayType = cast<IRArrayType>(type);
            for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++)
            {
                offset = _getAnyValueSizeRaw(arrayType->getElementType(), offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
        case kIROp_AnyValueType:
        {
            auto anyValueType = cast<IRAnyValueType>(type);
            return alignUp(offset, 4) + (SlangInt)getIntVal(anyValueType->getSize());
        }
        case kIROp_TupleType:
        {
            auto tupleType = cast<IRTupleType>(type);
            for (UInt i = 0; i < tupleType->getOperandCount(); i++)
            {
                auto elementType = tupleType->getOperand(i);
                offset = _getAnyValueSizeRaw((IRType*)elementType, offset);
                if (offset < 0) return offset;
            }
            return offset;
        }
        case kIROp_WitnessTableType:
        case kIROp_WitnessTableIDType:
        case kIROp_RTTIHandleType:
        {
            return alignUp(offset, 4) + kRTTIHandleSize;
        }
        case kIROp_InterfaceType:
        {
            auto interfaceType = cast<IRInterfaceType>(type);
            auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
            size += kRTTIHeaderSize;
            return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
        }
        case kIROp_AssociatedType:
        {
            auto associatedType = cast<IRAssociatedType>(type);
            SlangInt maxSize = 0;
            for (UInt i = 0; i < associatedType->getOperandCount(); i++)
                maxSize = Math::Max(maxSize, _getAnyValueSizeRaw((IRType*)associatedType->getOperand(i), offset));
            return maxSize;
        }
        case kIROp_ThisType:
        {
            auto thisType = cast<IRThisType>(type);
            auto interfaceType = thisType->getConstraintType();
            auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
            return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
        }
        case kIROp_ExtractExistentialType:
        {
            auto existentialValue = type->getOperand(0);
            auto interfaceType = cast<IRInterfaceType>(existentialValue->getDataType());
            auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc);
            return alignUp(offset, 4) + alignUp((SlangInt)size, 4);
        }
        case kIROp_LookupWitness:
        {
            auto witnessTableVal = type->getOperand(0);
            auto key = type->getOperand(1);
            IRType* assocType = nullptr;
            if (auto witnessTableType = as<IRWitnessTableTypeBase>(witnessTableVal->getDataType()))
            {
                auto interfaceType = as<IRInterfaceType>(witnessTableType->getConformanceType());

                // Walk through interface operands to find a match, the result should be an 
                // associated type entry.
                //
                for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++)
                {
                    auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii));
                    if (entry->getRequirementKey() == key && 
                        as<IRAssociatedType>(entry->getRequirementVal()))
                    {
                        assocType = (IRType*)entry->getRequirementVal();
                        break;
                    }
                }
            }

            if (!assocType)
                return -1;
            
            IRIntegerValue anyValueSize = kInvalidAnyValueSize;
            for (UInt i = 0; i < assocType->getOperandCount(); i++)
            {
                anyValueSize = Math::Min(
                    anyValueSize,
                    SharedGenericsLoweringContext::getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc));
            }

            if (anyValueSize == kInvalidAnyValueSize)
                return -1;
            
            return alignUp(offset, 4) + alignUp((SlangInt)anyValueSize, 4);
        }
        default:
            if (as<IRTextureTypeBase>(type) || as<IRSamplerStateTypeBase>(type))
            {
                return alignUp(offset, 4) + 8;
            }
            return -1;
        }
    }

    SlangInt getAnyValueSize(IRType* type)
    {
        auto rawSize = _getAnyValueSizeRaw(type, 0);
        if (rawSize < 0) return rawSize;
        return alignUp(rawSize, 4);
    }
}
back to top