// Source: https://github.com/shader-slang/slang
// Raw File
// Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
// [LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
// Tip revision: 5902acd
// File: slang-ir-specialize-dispatch.cpp
#include "slang-ir-specialize-dispatch.h"

#include "slang-ir-generics-lowering-context.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
#include "slang-ir-util.h"

namespace Slang
{
// Replaces `dispatchFunc` with a new function that selects the concrete implementation by
// switching on a witness table's sequential ID.
//
// `dispatchFunc` must consist of a single block containing a call, a witness lookup and a
// return. Its first parameter is a witness table; the replacement function takes an
// `IRWitnessTableID` in that position and keeps all other parameter types and the result type.
// For every witness table of the conformance type that carries an `IRSequentialIDDecoration`,
// a block is generated that calls the table's entry for the looked-up requirement key and
// returns its result. The last such table is handled by the `default` target of the switch.
//
// Witness tables without a sequential ID are reported through the diagnostic sink and are left
// out of the generated dispatch code.
//
// All uses of `dispatchFunc` are redirected to the new function and `dispatchFunc` is removed and
// deallocated; the pointer passed in must not be used after this call.
//
// Returns the newly created dispatch function.
IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
    auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
    auto conformanceType = cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType();
    // Collect all witness tables of `witnessTableType` in current module.
    List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(conformanceType);

    // Only tables with a sequential ID can be selected by the switch. Report the others and drop
    // them up front, so that the code below never dereferences a missing decoration and the
    // "last table goes into the default block" rule applies to tables that are actually usable.
    List<IRWitnessTable*> dispatchableTables;
    for (auto witnessTable : witnessTables)
    {
        if (witnessTable->findDecoration<IRSequentialIDDecoration>())
        {
            dispatchableTables.add(witnessTable);
        }
        else
        {
            sharedContext->sink->diagnose(witnessTable->getConcreteType(), Diagnostics::typeCannotBeUsedInDynamicDispatch, witnessTable->getConcreteType());
        }
    }
    const Index dispatchableCount = dispatchableTables.getCount();

    SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
    auto block = dispatchFunc->getFirstBlock();

    // The dispatch function before modification must be in the form of
    // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
    // We now find the relevant instructions.
    IRCall* callInst = nullptr;
    IRLookupWitnessMethod* lookupInst = nullptr;
    // Only used in debug builds as a sanity check
    [[maybe_unused]] IRReturn* returnInst = nullptr;
    for (auto inst : block->getOrdinaryInsts())
    {
        switch (inst->getOp())
        {
        case kIROp_Call:
            callInst = cast<IRCall>(inst);
            break;
        case kIROp_LookupWitness:
            lookupInst = cast<IRLookupWitnessMethod>(inst);
            break;
        case kIROp_Return:
            returnInst = cast<IRReturn>(inst);
            break;
        default:
            break;
        }
    }
    SLANG_ASSERT(callInst && lookupInst && returnInst);

    IRBuilder builderStorage(sharedContext->module);
    auto builder = &builderStorage;
    builder->setInsertBefore(dispatchFunc);

    // Create a new dispatch func to replace the existing one.
    auto newDispatchFunc = builder->createFunc();

    List<IRType*> paramTypes;
    for (auto paramInst : dispatchFunc->getParams())
    {
        paramTypes.add(paramInst->getFullType());
    }

    // Modify the first parameter from IRWitnessTable to IRWitnessTableID representing the sequential ID.
    paramTypes[0] = builder->getWitnessTableIDType(static_cast<IRType*>(conformanceType));

    auto newDispatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType());
    newDispatchFunc->setFullType(newDispatchFuncType);
    dispatchFunc->transferDecorationsTo(newDispatchFunc);

    builder->setInsertInto(newDispatchFunc);
    auto newBlock = builder->emitBlock();

    IRBlock* defaultBlock = nullptr;

    auto requirementKey = lookupInst->getRequirementKey();
    // `params` holds the arguments forwarded to the concrete callee: every parameter of the new
    // function except the leading witness table ID.
    List<IRInst*> params;
    for (Index i = 0; i < paramTypes.getCount(); i++)
    {
        auto param = builder->emitParam(paramTypes[i]);
        if (i > 0)
            params.add(param);
    }
    auto witnessTableParam = newBlock->getFirstParam();

    // `witnessTableParam` is expected to have `IRWitnessTableID` type, which
    // will later lower into a `uint2`. We only use the first element of the uint2
    // to store the sequential ID and reserve the second 32-bit value for future
    // pointer-compatibility. We insert a member extract inst right now
    // to obtain the first element and use it in our switch statement.
    UInt elemIdx = 0;
    auto witnessTableSequentialID =
        builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx);

    // Generate case blocks for each dispatchable witness table.
    // `caseBlocks` is a flat list of (sequential ID operand, target block) pairs.
    List<IRInst*> caseBlocks;
    for (Index i = 0; i < dispatchableCount; i++)
    {
        auto witnessTable = dispatchableTables[i];

        builder->setInsertInto(newDispatchFunc);
        if (i != dispatchableCount - 1)
        {
            // Create a case block if we are not the last case.
            // The decoration is guaranteed to exist because of the filtering above.
            auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>();
            caseBlocks.add(seqIdDecoration->getSequentialIDOperand());
            auto caseBlock = builder->emitBlock();
            caseBlocks.add(caseBlock);
        }
        else
        {
            // Generate code for the last possible value in the `default` block.
            defaultBlock = builder->emitBlock();
            builder->setInsertInto(defaultBlock);
        }

        auto callee = findWitnessTableEntry(witnessTable, requirementKey);
        SLANG_ASSERT(callee);
        auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
        if (callInst->getDataType()->getOp() == kIROp_VoidType)
            builder->emitReturn();
        else
            builder->emitReturn(specializedCallInst);
    }

    // Emit a switch statement to call the correct concrete function based on
    // the witness table sequential ID passed in.
    builder->setInsertInto(newDispatchFunc);

    if (dispatchableCount == 1)
    {
        // If there is only 1 case, no switch statement is necessary.
        builder->setInsertInto(newBlock);
        builder->emitBranch(defaultBlock);
    }
    else if (dispatchableCount > 1)
    {
        // Every case returns, so the block after the switch can never be reached.
        auto breakBlock = builder->emitBlock();
        builder->setInsertInto(breakBlock);
        builder->emitUnreachable();

        builder->setInsertInto(newBlock);
        builder->emitSwitch(
            witnessTableSequentialID,
            breakBlock,
            defaultBlock,
            caseBlocks.getCount(),
            caseBlocks.getBuffer());
    }
    else
    {
        // We have no usable witness tables that implement this interface.
        // Just return a default value.
        builder->setInsertInto(newBlock);
        if (callInst->getDataType()->getOp() == kIROp_VoidType)
        {
            builder->emitReturn();
        }
        else
        {
            auto defaultValue = builder->emitDefaultConstruct(callInst->getDataType());
            builder->emitReturn(defaultValue);
        }
    }
    // Remove old implementation.
    dispatchFunc->replaceUsesWith(newDispatchFunc);
    dispatchFunc->removeAndDeallocate();

    return newDispatchFunc;
}

// Reports whether `witness` is reachable from an instruction that has a linkage decoration.
//
// Starting from `witness`, the search follows every use whose user is a
// `kIROp_WitnessTableEntry` and moves on to that entry's parent. The result is true as soon as
// any visited instruction (including `witness` itself) has an `IRLinkageDecoration`, and false
// once everything reachable this way has been examined.
bool _isWitnessTableTransitivelyVisible(IRInst* witness)
{
    if (witness->findDecoration<IRLinkageDecoration>())
        return true;

    // `pending` grows while it is being walked, so it is traversed by index rather than with a
    // range-for. `enqueued` keeps a parent from being queued more than once.
    OrderedHashSet<IRInst*> enqueued;
    List<IRInst*> pending;
    pending.add(witness);

    for (Index cursor = 0; cursor < pending.getCount(); ++cursor)
    {
        auto current = pending[cursor];
        if (current->findDecoration<IRLinkageDecoration>())
            return true;

        for (auto use = current->firstUse; use; use = use->nextUse)
        {
            auto user = use->getUser();
            if (user->getOp() != kIROp_WitnessTableEntry)
                continue;

            auto owner = user->getParent();
            if (!owner)
                continue;

            if (enqueued.add(owner))
                pending.add(owner);
        }
    }
    return false;
}

// Assigns a sequential ID to every witness table in the module that may take part in dynamic
// dispatch.
//
// For each `kIROp_WitnessTable` global instruction that does not yet have an
// `IRSequentialIDDecoration`:
//  - A table with an `IRLinkageDecoration` is identified by its mangled name.
//  - A table without linkage is skipped when its conformance type has an
//    `IRSpecializeDecoration`, or when `_isWitnessTableTransitivelyVisible` returns false.
//    Otherwise a unique name is synthesized for it.
// The name is looked up in the Linkage's `mapMangledNameToRTTIObjectIndex`. If it is present, the
// stored ID is reused. Otherwise the next value of the per-interface counter in
// `mapInterfaceMangledNameToSequentialIDCounters` is taken and recorded in the map, so that it
// can be looked up by the user via future Slang API calls.
// Finally the ID is attached to the table as a sequential ID decoration.
//
// Tables that were skipped do not receive a decoration.
void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContext)
{
    auto linkage = sharedContext->targetReq->getLinkage();
    for (auto inst : sharedContext->module->getGlobalInsts())
    {
        if (inst->getOp() != kIROp_WitnessTable)
            continue;

        // A table that already has an ID needs no further work. Checking this first avoids the
        // visibility search and avoids consuming a `uniqueId` for a name that would be discarded.
        if (inst->findDecoration<IRSequentialIDDecoration>())
            continue;

        // Backing storage for a synthesized name. It is created fresh for every table so that a
        // synthesized name never contains the names generated for earlier tables, and it lives
        // until the end of this iteration because `witnessTableMangledName` may point into it.
        StringBuilder generatedMangledName;

        UnownedStringSlice witnessTableMangledName;
        if (auto instLinkage = inst->findDecoration<IRLinkageDecoration>())
        {
            witnessTableMangledName = instLinkage->getMangledName();
        }
        else
        {
            auto witnessTableType = as<IRWitnessTableType>(inst->getDataType());
            if (witnessTableType && witnessTableType->getConformanceType()->findDecoration<IRSpecializeDecoration>())
            {
                // The interface is for specialization only, it would be an error if dynamic dispatch is used
                // through the interface.
                // Skip assigning ID for the witness table.
                continue;
            }

            // If this witness table does not have a linkage,
            // we need to check if it is transitively visible via
            // associatedtypes from an existing witness table with linkage.
            // If so we still need to include this witness table, otherwise
            // don't assign sequential ID for it.
            if (!_isWitnessTableTransitivelyVisible(inst))
                continue;

            // Generate a unique linkage name for it.
            // NOTE(review): the counter is a plain function-local static; if this pass can run
            // on several threads at once, the increment is unsynchronized. Confirm with callers.
            static int32_t uniqueId = 0;
            uniqueId++;
            if (auto nameHint = inst->findDecoration<IRNameHintDecoration>())
            {
                generatedMangledName << nameHint->getName();
            }
            generatedMangledName << "_generated_witness_uuid_" << uniqueId;
            witnessTableMangledName = generatedMangledName.getUnownedSlice();
        }

        // Get a sequential ID for the witness table using the map from the Linkage.
        uint32_t seqID = 0;
        if (!linkage->mapMangledNameToRTTIObjectIndex.tryGetValue(
            witnessTableMangledName, seqID))
        {
            // The name is not known yet: allocate the next ID of the table's interface.
            auto interfaceType =
                cast<IRWitnessTableType>(inst->getDataType())->getConformanceType();
            auto interfaceLinkage = interfaceType->findDecoration<IRLinkageDecoration>();
            SLANG_ASSERT(
                interfaceLinkage && "An interface type does not have a linkage,"
                                    "but a witness table associated with it has one.");
            auto interfaceName = interfaceLinkage->getMangledName();
            auto idAllocator =
                linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(
                    interfaceName);
            if (!idAllocator)
            {
                // First table seen for this interface: start its counter at zero.
                linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0;
                idAllocator =
                    linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(
                        interfaceName);
            }
            seqID = *idAllocator;
            ++(*idAllocator);
            linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID;
        }

        // Add a decoration to the inst.
        IRBuilder builder(sharedContext->module);
        builder.setInsertBefore(inst);
        builder.addSequentialIDDecoration(inst, seqID);
    }
}

// Re-emits the call sites of a specialized dispatch function.
//
// Every `IRCall` whose callee is `newDispatchFunc` is replaced by a freshly emitted call to
// `newDispatchFunc` with the same full type and the same arguments; uses of the old call are
// redirected to the new one and the old call is removed.
//
// NOTE(review): the arguments are forwarded unchanged here, so the conversion of the witness
// table argument into its sequential ID is presumably performed elsewhere - confirm. Also, the
// skip condition below tests the *data type* of the first argument against `IRWitnessTable`;
// verify that this is the intended check.
void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc)
{
    // Take a snapshot of the users first: replacing calls below modifies the use list of
    // `newDispatchFunc`, so it cannot be walked while it is being edited.
    List<IRInst*> snapshot;
    for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse)
        snapshot.add(use->getUser());

    for (auto user : snapshot)
    {
        auto oldCall = as<IRCall>(user);
        if (!oldCall)
            continue;

        // Ignore calls that merely pass the function along as an argument.
        if (oldCall->getCallee() != newDispatchFunc)
            continue;

        List<IRInst*> callArgs;
        for (UInt argIndex = 0; argIndex < oldCall->getArgCount(); ++argIndex)
            callArgs.add(oldCall->getArg(argIndex));

        if (as<IRWitnessTable>(callArgs[0]->getDataType()))
            continue;

        IRBuilder builder(sharedContext->module);
        builder.setInsertBefore(oldCall);
        auto replacement =
            builder.emitCallInst(oldCall->getFullType(), newDispatchFunc, callArgs);
        oldCall->replaceUsesWith(replacement);
        oldCall->removeAndDeallocate();
    }
}

// Driver of the pass: assigns sequential IDs to witness tables, then rewrites every dispatch
// function recorded in `mapInterfaceRequirementKeyToDispatchMethods`.
//
// NOTE(review): `specializeDispatchFunction` removes and deallocates the function it is given,
// while the map iterated here is not updated. Presumably the map is not consulted again after
// this pass - confirm.
void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
{
    // Sequential IDs must exist before any switch can be generated from them.
    ensureWitnessTableSequentialIDs(sharedContext);

    for (const auto& [_, originalDispatchFunc] : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
    {
        // Build the `switch`-based replacement from the witness tables present in the module.
        auto switchBasedFunc = specializeDispatchFunction(sharedContext, originalDispatchFunc);

        // Then re-emit the call sites that now refer to the replacement.
        fixupDispatchFuncCall(sharedContext, switchBasedFunc);
    }
}
} // namespace Slang
// back to top