https://github.com/shader-slang/slang
Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-lower-append-consume-structured-buffer.cpp
#include "slang-ir-lower-append-consume-structured-buffer.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir-layout.h"
#include "slang-ir-lower-buffer-element-type.h"

namespace Slang
{
    static void lowerStructuredBufferType(TargetRequest* target, IRHLSLStructuredBufferTypeBase* type)
    {
        IRBuilder builder(type);
        builder.setInsertBefore(type);

        auto elementType = type->getElementType();

        // Type.
        auto structType = builder.createStructType();
        StringBuilder nameSb;
        if (type->getOp() == kIROp_HLSLAppendStructuredBufferType)
            nameSb << "AppendStructuredBuffer_";
        else
            nameSb << "ConsumeStructuredBuffer_";
        getTypeNameHint(nameSb, elementType);
        nameSb << "_t";
        builder.addNameHintDecoration(structType, nameSb.produceString().getUnownedSlice());

        auto elementBufferKey = builder.createStructKey();
        builder.addNameHintDecoration(elementBufferKey, UnownedStringSlice("elements"));

        auto counterBufferKey = builder.createStructKey();
        builder.addNameHintDecoration(counterBufferKey, UnownedStringSlice("counter"));

        builder.addDecoration(elementBufferKey, kIROp_CounterBufferDecoration, counterBufferKey);

        auto elementBufferType = builder.getType(kIROp_HLSLRWStructuredBufferType, elementType);
        auto counterBufferType = builder.getType(kIROp_HLSLRWStructuredBufferType, builder.getIntType());

        builder.createStructField(structType, elementBufferKey, elementBufferType);
        builder.createStructField(structType, counterBufferKey, counterBufferType);

        // Type layout.
        auto layoutRules = getTypeLayoutRuleForBuffer(target, type);

        IRTypeLayout::Builder elementTypeLayoutBuilder(&builder);
        IRSizeAndAlignment elementSize;
        getSizeAndAlignment(target, layoutRules, elementType, &elementSize);
        elementTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::Uniform, LayoutSize((LayoutSize::RawValue)elementSize.getStride()));
        auto elementTypeLayout = elementTypeLayoutBuilder.build();

        IRStructuredBufferTypeLayout::Builder elementBufferTypeLayoutBuilder(&builder, elementTypeLayout);
        elementBufferTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::DescriptorTableSlot, 1);
        auto elementBufferTypeLayout = elementBufferTypeLayoutBuilder.build();

        IRTypeLayout::Builder counterTypeLayoutBuilder(&builder);
        counterTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::Uniform, LayoutSize(4));
        auto counterTypeLayout = counterTypeLayoutBuilder.build();

        IRStructuredBufferTypeLayout::Builder counterBufferTypeLayoutBuilder(&builder, counterTypeLayout);
        counterBufferTypeLayoutBuilder.addResourceUsage(LayoutResourceKind::DescriptorTableSlot, 1);
        auto counterBufferTypeLayout = counterBufferTypeLayoutBuilder.build();

        IRVarLayout::Builder elementBufferVarLayoutBuilder(&builder, elementBufferTypeLayout);
        elementBufferVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot)->offset = 0;
        auto elementBufferVarLayout = elementBufferVarLayoutBuilder.build();

        IRVarLayout::Builder counterBufferVarLayoutBuilder(&builder, counterBufferTypeLayout);
        counterBufferVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot)->offset = 1;
        auto counterBufferVarLayout = counterBufferVarLayoutBuilder.build();

        IRStructTypeLayout::Builder layoutBuilder(&builder);
        layoutBuilder.addField(elementBufferKey, elementBufferVarLayout);
        layoutBuilder.addField(counterBufferKey, counterBufferVarLayout);
        auto typeLayout = layoutBuilder.build();

        builder.addLayoutDecoration(structType, typeLayout);

        IRFunc* appendFunc = nullptr;
        IRFunc* consumeFunc = nullptr;
        IRFunc* getDimensionsFunc = nullptr;

        if (type->getOp() == kIROp_HLSLAppendStructuredBufferType)
        {
            // Append method.
            appendFunc = builder.createFunc();
            builder.addNameHintDecoration(appendFunc, UnownedStringSlice("AppendStructuredBuffer_Append"));
            IRType* paramTypes[] = { structType, elementType };
            auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType());
            appendFunc->setFullType(funcType);
            builder.setInsertInto(appendFunc);
            builder.emitBlock();
            auto bufferParam = builder.emitParam(structType);
            auto elementParam = builder.emitParam(elementType);
            auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey);
            auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
            IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
            auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
            auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterIncrement, 1, &counterBufferPtr);

            IRInst* getElementPtrArgs[] = { elementBuffer, oldCounter };
            auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs);

            builder.emitStore(elementBufferPtr, elementParam);
            builder.emitReturn();
        }
        else
        {
            // Consume method.
            consumeFunc = builder.createFunc();
            builder.addNameHintDecoration(consumeFunc, UnownedStringSlice("ConsumeStructuredBuffer_Consume"));
            IRType* paramTypes[] = { structType };
            auto funcType = builder.getFuncType(1, paramTypes, elementType);
            consumeFunc->setFullType(funcType);
            builder.setInsertInto(consumeFunc);
            auto firstBlock = builder.emitBlock();
            auto bufferParam = builder.emitParam(structType);
            auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey);
            auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
            IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
            auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
            auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterDecrement, 1, &counterBufferPtr);
            auto index = builder.emitSub(builder.getIntType(), oldCounter, builder.getIntValue(builder.getIntType(), 1));

            // Test if index is greater or equal than 0.
            auto geq = builder.emitGeq(index, builder.getIntValue(builder.getIntType(), 0));
            auto trueBlock = builder.emitBlock();

            auto falseBlock = builder.emitBlock();
            auto mergeBlock = builder.emitBlock();

            builder.setInsertInto(firstBlock);
            builder.emitIfElse(geq, trueBlock, falseBlock, mergeBlock);

            builder.setInsertInto(trueBlock);
            IRInst* getElementPtrArgs[] = { elementBuffer, index };
            auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs);
            auto val = builder.emitLoad(elementBufferPtr);
            builder.emitReturn(val);

            builder.setInsertInto(falseBlock);
            auto defaultVal = builder.emitDefaultConstruct(elementType);
            builder.emitReturn(defaultVal);

            builder.setInsertInto(mergeBlock);
            builder.emitUnreachable();
        }

        // GetDimensions method.
        {
            getDimensionsFunc = builder.createFunc();
            builder.addNameHintDecoration(getDimensionsFunc, UnownedStringSlice("StructuredBuffer_GetDimensions"));
            IRType* paramTypes[] = { structType };
            auto uint2Type = builder.getVectorType(builder.getUIntType(), 2);
            auto funcType = builder.getFuncType(1, paramTypes, uint2Type);
            getDimensionsFunc->setFullType(funcType);
            builder.setInsertInto(getDimensionsFunc);
            builder.emitBlock();
            auto bufferParam = builder.emitParam(structType);
            auto elementBuffer = builder.emitFieldExtract(elementBufferType, bufferParam, elementBufferKey);

            const auto dim = builder.emitIntrinsicInst(
                uint2Type,
                kIROp_StructuredBufferGetDimensions,
                1,
                &elementBuffer
            );
            builder.emitReturn(dim);
        }

        // Replace all insts with synthesized functions.
        traverseUsers(type, [&](IRInst* typeUser)
            {
                if (typeUser->getFullType() != type)
                    return;
                if (auto layoutDecor = typeUser->findDecoration<IRLayoutDecoration>())
                {
                    // Replace the original StructuredBufferVarLayout with the new StructTypeVarLayout.
                    if (auto varLayout = as<IRVarLayout>(layoutDecor->getLayout()))
                    {
                        IRBuilder subBuilder(typeUser);
                        IRVarLayout::Builder newVarLayoutBuilder(&subBuilder, typeLayout);
                        newVarLayoutBuilder.cloneEverythingButOffsetsFrom(varLayout);
                        IRVarOffsetAttr* uavOffsetAttr = nullptr;
                        IRVarOffsetAttr* descriptorTableSlotOffsetAttr = nullptr;

                        for (auto offsetAttr : varLayout->getOffsetAttrs())
                        {
                            if (offsetAttr->getResourceKind() == LayoutResourceKind::UnorderedAccess)
                                uavOffsetAttr = offsetAttr;
                            else if (offsetAttr->getResourceKind() == LayoutResourceKind::DescriptorTableSlot)
                                descriptorTableSlotOffsetAttr = offsetAttr;
                            auto info = newVarLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
                            info->offset = offsetAttr->getOffset();
                            info->space = offsetAttr->getSpace();
                            info->kind = offsetAttr->getResourceKind();
                        }
                        // If the user provided an layout offset for UAV but not for descriptor table slot, then
                        // we use the UAV offset for the descriptor table slot offset.
                        if (uavOffsetAttr && !descriptorTableSlotOffsetAttr)
                        {
                            auto info = newVarLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::DescriptorTableSlot);
                            info->offset = uavOffsetAttr->getOffset();
                            info->space = uavOffsetAttr->getSpace();
                            info->kind = LayoutResourceKind::DescriptorTableSlot;
                        }
                        auto newVarLayout = newVarLayoutBuilder.build();
                        subBuilder.addLayoutDecoration(typeUser, newVarLayout);
                        varLayout->removeAndDeallocate();
                    }
                }
                traverseUses(typeUser, [&](IRUse* use)
                    {
                        auto user = use->getUser();
                        switch (user->getOp())
                        {
                        case kIROp_StructuredBufferAppend:
                            {
                                IRBuilder subBuilder(user);
                                subBuilder.setInsertBefore(user);
                                IRInst* args[] = { user->getOperand(0), user->getOperand(1) };
                                auto call = subBuilder.emitCallInst(user->getFullType(), appendFunc, 2, args);
                                user->replaceUsesWith(call);
                                user->removeAndDeallocate();
                                break;
                            }
                        case kIROp_StructuredBufferConsume:
                            {
                                IRBuilder subBuilder(user);
                                subBuilder.setInsertBefore(user);
                                IRInst* args[] = { user->getOperand(0) };
                                auto call = subBuilder.emitCallInst(user->getFullType(), consumeFunc, 1, args);
                                user->replaceUsesWith(call);
                                user->removeAndDeallocate();
                                break;
                            }
                        case kIROp_StructuredBufferGetDimensions:
                            {
                                IRBuilder subBuilder(user);
                                subBuilder.setInsertBefore(user);
                                IRInst* args[] = { user->getOperand(0) };
                                auto call = subBuilder.emitCallInst(user->getFullType(), getDimensionsFunc, 1, args);
                                user->replaceUsesWith(call);
                                user->removeAndDeallocate();
                                break;
                            }
                        }
                    });
            });
        type->replaceUsesWith(structType);
    }

    void lowerAppendConsumeStructuredBuffers(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
    {
        SLANG_UNUSED(sink);
        for (auto globalInst : module->getGlobalInsts())
        {
            switch (globalInst->getOp())
            {
            case kIROp_HLSLAppendStructuredBufferType:
            case kIROp_HLSLConsumeStructuredBufferType:
                lowerStructuredBufferType(target, as<IRHLSLStructuredBufferTypeBase>(globalInst));
                break;
            }
        }
    }
}
back to top