https://github.com/shader-slang/slang
Raw File
Tip revision: 0586f3298fa7d554fa2682103eefba88740d6758 authored by jsmall-nvidia on 18 January 2023, 19:11:50 UTC
Upgrade slang-llvm-13.x-33 (#2600)
Tip revision: 0586f32
slang-ir-lower-generic-call.cpp
// slang-ir-lower-generic-call.cpp
#include "slang-ir-lower-generic-call.h"
#include "slang-ir-generics-lowering-context.h"
#include "slang-ir-util.h"

namespace Slang
{
    struct GenericCallLoweringContext
    {
        SharedGenericsLoweringContext* sharedContext;

        // Work item describing an `out`/`inout` argument that was handed to a generic
        // call through a packed temporary and has to be unpacked once the call is emitted.
        struct ArgumentUnpackWorkItem
        {
            // The caller's original, concretely typed argument; the unpacked value is
            // stored back into it after the call.
            IRInst* dstArg{nullptr};
            // The temporary holding the packed value that was passed to the callee.
            IRInst* packedArg{nullptr};
        };

        // Packs `arg` into an `AnyValue` when the parameter expects one and `arg` is
        // not already an `AnyValue`, so that it can be passed for `paramType`.
        //
        // `paramType` is the callee-side parameter type and `arg` is the caller-side
        // argument. Either may be a pointer type (because of `inout` or `out`
        // modifiers); the decision to pack is made on the underlying value types.
        //
        // Returns the instruction to pass to the callee:
        //   - `arg` itself when no packing is needed;
        //   - the packed value when the parameter is passed by value;
        //   - a new temporary variable holding the packed value when the parameter
        //     is a pointer.
        // Only in the last case is `unpackAfterCall` filled in (destination = `arg`,
        // packed source = the temporary), telling the caller to unpack the temporary
        // back into `arg` after the call. In every other case both fields are null.
        IRInst* maybePackArgument(
            IRBuilder* builder,
            IRType* paramType,
            IRInst* arg,
            ArgumentUnpackWorkItem& unpackAfterCall)
        {
            unpackAfterCall.dstArg = nullptr;
            unpackAfterCall.packedArg = nullptr;

            // Strip the pointer layer (if any) from the parameter type.
            IRType* paramValType = paramType;
            auto paramPtrType = as<IRPtrTypeBase>(paramType);
            if (paramPtrType)
                paramValType = paramPtrType->getValueType();

            // Strip the pointer layer (if any) from the argument type.
            IRType* argValType = arg->getDataType();
            auto argPtrType = as<IRPtrTypeBase>(argValType);
            if (argPtrType)
                argValType = argPtrType->getValueType();

            // Packing is only required when the parameter expects an AnyValue and
            // the argument is not one already.
            if (!as<IRAnyValueType>(paramValType) || as<IRAnyValueType>(argValType))
                return arg;

            // Load through the argument pointer only now that the value is actually
            // needed, so that arguments which are passed through unchanged do not
            // leave an unused load behind.
            IRInst* argVal = arg;
            if (argPtrType)
                argVal = builder->emitLoad(arg);

            auto packedArgVal = builder->emitPackAnyValue(paramValType, argVal);
            if (!paramPtrType)
                return packedArgVal;

            // The parameter expects a pointer: store the packed value into a
            // temporary variable and pass a pointer to that variable instead.
            auto tempVar = builder->emitVar(paramValType);
            builder->emitStore(tempVar, packedArgVal);

            // The temporary must be unpacked into the original variable after the call.
            unpackAfterCall.dstArg = arg;
            unpackAfterCall.packedArg = tempVar;
            return tempVar;
        }

        // Returns `value` converted to `expectedType` when `actualType` is an `AnyValue`
        // type and `expectedType` is not; otherwise returns `value` unchanged.
        IRInst* maybeUnpackValue(IRBuilder* builder, IRType* expectedType, IRType* actualType, IRInst* value)
        {
            // Nothing to unpack if the value is not an AnyValue to begin with.
            if (!as<IRAnyValueType>(actualType))
                return value;

            // The consumer wants an AnyValue as well: pass it through as is.
            if (as<IRAnyValueType>(expectedType))
                return value;

            return builder->emitUnpackAnyValue(expectedType, value);
        }

        // Creates the dispatch function for an interface method.
        //
        // The function takes a witness table for `interfaceType` as its first parameter,
        // followed by the parameters of the requirement's function type
        // (`requirementVal`). Its body looks up `requirementKey` in the witness table,
        // calls the result with the remaining parameters and returns what the call
        // produces (or nothing, for a `void` result).
        //
        // On CPU this is a witness table lookup followed by a function-pointer call.
        // On GPU targets, a follow-up pass can rewrite the body into a `switch`
        // statement based on the type ID.
        //
        // Note that `builder`'s insertion point is left inside the new function.
        IRFunc* _createInterfaceDispatchMethod(
            IRBuilder* builder,
            IRInterfaceType* interfaceType,
            IRInst* requirementKey,
            IRInst* requirementVal)
        {
            auto dispatchFunc = builder->createFunc();

            // Reuse the requirement's mangled name as a name hint, when it has one.
            auto linkageDecor = requirementKey->findDecoration<IRLinkageDecoration>();
            if (linkageDecor)
                builder->addNameHintDecoration(dispatchFunc, linkageDecor->getMangledName());

            auto reqFuncType = cast<IRFuncType>(requirementVal);
            auto resultType = reqFuncType->getResultType();

            // Signature: (witnessTable, <requirement parameters>...) -> <requirement result>.
            List<IRType*> paramTypes;
            paramTypes.add(builder->getWitnessTableType(interfaceType));
            for (UInt reqParamIndex = 0; reqParamIndex < reqFuncType->getParamCount(); ++reqParamIndex)
                paramTypes.add(reqFuncType->getParamType(reqParamIndex));
            dispatchFunc->setFullType(builder->getFuncType(paramTypes, resultType));

            // Body: a single block with one parameter per entry in `paramTypes`.
            builder->setInsertInto(dispatchFunc);
            builder->emitBlock();

            IRParam* witnessTableParam = builder->emitParam(paramTypes[0]);
            List<IRInst*> forwardedArgs;
            for (Index paramIndex = 1; paramIndex < paramTypes.getCount(); ++paramIndex)
                forwardedArgs.add(builder->emitParam(paramTypes[paramIndex]));

            // Look the method up in the witness table and forward the arguments to it.
            auto callee = builder->emitLookupInterfaceMethodInst(
                reqFuncType, witnessTableParam, requirementKey);
            auto call = static_cast<IRCall*>(
                builder->emitCallInst(resultType, callee, forwardedArgs));

            if (call->getDataType()->getOp() == kIROp_VoidType)
                builder->emitReturn();
            else
                builder->emitReturn(call);

            return dispatchFunc;
        }

        // Returns the dispatch method registered for `requirementKey` in the shared
        // context. If none exists yet, one is created, registered and returned.
        IRFunc* getOrCreateInterfaceDispatchMethod(
            IRBuilder* builder,
            IRInterfaceType* interfaceType,
            IRInst* requirementKey,
            IRInst* requirementVal)
        {
            auto& dispatchMethods = sharedContext->mapInterfaceRequirementKeyToDispatchMethods;

            if (auto existing = dispatchMethods.TryGetValue(requirementKey))
                return *existing;

            auto created =
                _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal);
            dispatchMethods.AddIfNotExists(requirementKey, created);
            return created;
        }

        // Replaces `callInst` with a call to `newCallee`, whose signature is `funcType`.
        //
        // The new argument list consists of the original arguments (packed into
        // `AnyValue`s where the corresponding parameter requires it), followed by the
        // arguments of `specializeInst` when it is non-null. If `newCallee` is a lowered
        // generic function, `specializeInst` holds the type arguments that were used
        // to specialize the callee.
        //
        // After the call, the result is unpacked if the callee returns an `AnyValue`
        // while `callInst` has a concrete type, and every `out`/`inout` argument that
        // was passed through a packed temporary is unpacked back into its original
        // destination. Finally all uses of `callInst` are redirected to the (possibly
        // unpacked) new result and `callInst` is removed.
        void translateCallInst(
            IRCall* callInst,
            IRFuncType* funcType,
            IRInst* newCallee,
            IRSpecialize* specializeInst)
        {
            List<IRType*> paramTypes;
            for (UInt paramIndex = 0; paramIndex < funcType->getParamCount(); ++paramIndex)
                paramTypes.add(funcType->getParamType(paramIndex));

            IRBuilder localBuilder(sharedContext->sharedBuilderStorage);
            IRBuilder* builder = &localBuilder;
            builder->setInsertBefore(callInst);

            // Walk the original arguments, packing each one when its parameter needs it.
            // Arguments passed through a packed temporary are remembered in
            // `pendingUnpacks` so they can be written back once the new call exists.
            List<IRInst*> newArgs;
            List<ArgumentUnpackWorkItem> pendingUnpacks;
            for (UInt argIndex = 0; argIndex < callInst->getArgCount(); ++argIndex)
            {
                ArgumentUnpackWorkItem workItem;
                IRInst* originalArg = callInst->getArg(argIndex);
                newArgs.add(maybePackArgument(builder, paramTypes[argIndex], originalArg, workItem));
                if (workItem.packedArg)
                    pendingUnpacks.add(workItem);
            }

            // Append the specialization arguments after the ordinary ones.
            if (specializeInst)
            {
                for (UInt typeArgIndex = 0; typeArgIndex < specializeInst->getArgCount(); ++typeArgIndex)
                {
                    IRInst* typeArg = specializeInst->getArg(typeArgIndex);
                    if (as<IRType>(typeArg))
                    {
                        // A simple type is used to specialize the callee: generate its
                        // RTTI object and pass the address of that object.
                        auto rttiObject = sharedContext->maybeEmitRTTIObject(typeArg);
                        typeArg = builder->emitGetAddress(
                            builder->getRTTIHandleType(),
                            rttiObject);
                    }
                    else if (typeArg->getOp() == kIROp_Specialize)
                    {
                        // The type argument is itself a specialization of a generic type.
                        // TODO: generate RTTI object for specializations of generic types.
                        SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
                    }
                    // Anything else is passed down unchanged. This includes
                    // `kIROp_RTTIObject`: inside a generic function, the caller's generic
                    // parameter has already been translated into an RTTI object.
                    newArgs.add(typeArg);
                }
            }

            // Emit the new call, unpacking its result if the callee returns `AnyValue`
            // but a concrete value is expected.
            auto calleeRetType = funcType->getResultType();
            auto newCall = builder->emitCallInst(calleeRetType, newCallee, newArgs);
            auto replacement =
                maybeUnpackValue(builder, callInst->getDataType(), calleeRetType, newCall);

            // Write the packed `out`/`inout` temporaries back to their destinations.
            for (auto& pending : pendingUnpacks)
            {
                auto packedVal = builder->emitLoad(pending.packedArg);
                auto dstValType =
                    cast<IRPtrTypeBase>(pending.dstArg->getDataType())->getValueType();
                auto unpackedVal = builder->emitUnpackAnyValue(dstValType, packedVal);
                builder->emitStore(pending.dstArg, unpackedVal);
            }

            callInst->replaceUsesWith(replacement);
            callInst->removeAndDeallocate();
        }

        // Follows the chain of `specialize` bases starting at `inst` and returns the
        // first base that is not itself an `IRSpecialize`.
        IRInst* findInnerMostSpecializingBase(IRSpecialize* inst)
        {
            IRInst* base = inst->getBase();
            for (auto nested = as<IRSpecialize>(base); nested; nested = as<IRSpecialize>(base))
                base = nested->getBase();
            return base;
        }

        // Lowers a call whose callee is `specializeInst`, i.e. turns
        // `call(specialize(gFunc, Targs), args)` into `call(gFunc, args, Targs)`.
        //
        // Depending on the base of `specializeInst`:
        //   - a generic: the call is left untouched (intrinsic generic functions are
        //     not lowered and must be ignored here);
        //   - another `specialize`: the call is left untouched if the innermost function
        //     is a target intrinsic, otherwise this is reported as unexpected;
        //   - a witness lookup: handled by `lowerCallToInterfaceMethod`;
        //   - anything else: treated as a lowered function and translated directly.
        void lowerCallToSpecializedFunc(IRCall* callInst, IRSpecialize* specializeInst)
        {
            auto loweredFunc = specializeInst->getBase();
            const auto baseOp = loweredFunc->getOp();

            // All callees should have already been lowered in lower-generic-functions pass.
            // Intrinsic generic functions are left as is, so a remaining generic is skipped.
            if (baseOp == kIROp_Generic)
                return;

            if (baseOp == kIROp_Specialize)
            {
                // All nested generic functions are supposed to be flattened before this
                // pass. If they are not, they represent an intrinsic function that should
                // not be modified in this pass.
                auto innerMostFunc =
                    findInnerMostSpecializingBase(static_cast<IRSpecialize*>(loweredFunc));
                if (innerMostFunc && innerMostFunc->getOp() == kIROp_Generic)
                {
                    innerMostFunc =
                        findInnerMostGenericReturnVal(static_cast<IRGeneric*>(innerMostFunc));
                }
                // `innerMostFunc` may be null at this point (it is null-checked above and
                // then reassigned), so guard the decoration lookup as well. A null value
                // falls through to the unexpected-case report instead of being dereferenced.
                if (innerMostFunc && innerMostFunc->findDecoration<IRTargetIntrinsicDecoration>())
                    return;
                SLANG_UNEXPECTED("Nested generics specialization.");
            }
            else if (baseOp == kIROp_LookupWitness)
            {
                lowerCallToInterfaceMethod(
                    callInst, cast<IRLookupWitnessMethod>(loweredFunc), specializeInst);
                return;
            }

            IRFuncType* funcType = cast<IRFuncType>(loweredFunc->getDataType());
            translateCallInst(callInst, funcType, loweredFunc, specializeInst);
        }

        // Lowers `call(lookup_interface_method(w, key), args)` into a call of the
        // dispatch function for `key`, passing `w` as the first argument, and then
        // translates that new call so that all occurrences of associated types are
        // packed/unpacked as needed.
        //
        // The call is left untouched when `w` is typed as a COM interface, when the
        // interface is builtin, or when the call is the one inside the dispatch
        // function itself. `specializeInst` (may be null) is forwarded to
        // `translateCallInst`.
        void lowerCallToInterfaceMethod(IRCall* callInst, IRLookupWitnessMethod* lookupInst, IRSpecialize* specializeInst)
        {
            auto witnessTable = lookupInst->getWitnessTable();
            auto witnessTableType = witnessTable->getDataType();

            // If `w` in `lookup_interface_method(w, ...)` is a COM interface, bail.
            if (isComInterfaceType(witnessTableType))
                return;

            auto interfaceType = cast<IRInterfaceType>(
                cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType());
            if (isBuiltin(interfaceType))
                return;

            IRBuilder localBuilder(sharedContext->sharedBuilderStorage);
            IRBuilder* builder = &localBuilder;
            builder->setInsertBefore(callInst);

            auto requirementKey = lookupInst->getRequirementKey();
            auto requirementVal =
                sharedContext->findInterfaceRequirementVal(interfaceType, requirementKey);

            // Report dynamic dispatch through an interface carrying a specialize decoration.
            if (interfaceType->findDecoration<IRSpecializeDecoration>())
            {
                sharedContext->sink->diagnose(callInst->sourceLoc, Diagnostics::dynamicDispatchOnSpecializeOnlyInterface, interfaceType);
            }

            // The dispatch method bottlenecks the dispatch logic for this requirement.
            auto dispatchFunc = getOrCreateInterfaceDispatchMethod(
                builder, interfaceType, requirementKey, requirementVal);

            // Don't process the call inst that is the one in the dispatch function itself.
            if (getParentFunc(callInst) == dispatchFunc)
                return;

            // Creating the dispatch method moves the builder's insertion point into the
            // new function, so point it back at the call before emitting the replacement.
            builder->setInsertBefore(callInst);

            // New argument list: the witness table first, then the original arguments.
            List<IRInst*> dispatchArgs;
            dispatchArgs.add(witnessTable);
            for (UInt argIndex = 0; argIndex < callInst->getArgCount(); ++argIndex)
                dispatchArgs.add(callInst->getArg(argIndex));

            auto dispatchCall = static_cast<IRCall*>(
                builder->emitCallInst(callInst->getFullType(), dispatchFunc, dispatchArgs));
            callInst->replaceUsesWith(dispatchCall);
            callInst->removeAndDeallocate();

            // Translate the new call inst as normal, taking care of packing/unpacking
            // inputs and outputs.
            translateCallInst(
                dispatchCall,
                cast<IRFuncType>(dispatchFunc->getFullType()),
                dispatchFunc,
                specializeInst);
        }

        // Dispatches `callInst` to the matching lowering routine based on its callee:
        // a `specialize` inst or a witness lookup. Calls with any other callee are
        // left untouched.
        void lowerCall(IRCall* callInst)
        {
            auto callee = callInst->getCallee();

            if (auto specializeInst = as<IRSpecialize>(callee))
            {
                lowerCallToSpecializedFunc(callInst, specializeInst);
                return;
            }

            if (auto lookupInst = as<IRLookupWitnessMethod>(callee))
                lowerCallToInterfaceMethod(callInst, lookupInst, nullptr);
        }

        // Lowers `inst` if it is a call instruction; every other instruction is ignored.
        void processInst(IRInst* inst)
        {
            auto callInst = as<IRCall>(inst);
            if (!callInst)
                return;
            lowerCall(callInst);
        }

        void processModule()
        {
            // We start by initializing our shared IR building state,
            // since we will re-use that state for any code we
            // generate along the way.
            //
            SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
            sharedBuilder->init(sharedContext->module);

            sharedContext->addToWorkList(sharedContext->module->getModuleInst());

            while (sharedContext->workList.getCount() != 0)
            {
                IRInst* inst = sharedContext->workList.getLast();

                sharedContext->workList.removeLast();
                sharedContext->workListSet.Remove(inst);

                processInst(inst);

                for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
                {
                    sharedContext->addToWorkList(child);
                }
            }
        }
    };

    // Entry point of the pass: lowers the generic and interface-method calls in the
    // module owned by `sharedContext`.
    void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext)
    {
        GenericCallLoweringContext loweringContext{sharedContext};
        loweringContext.processModule();
    }

}
back to top