// slang-ir-legalize-varying-params.cpp #include "slang-ir-legalize-varying-params.h" #include "slang-ir-insts.h" namespace Slang { // This pass implements logic to "legalize" the varying parameter // signature of an entry point. // // The traditional Slang/HLSL model is to have varying input parameters // be marked with "semantics" that can either mark them as user-defined // or system-value parameters. In addition the result (return value) // of the function can be marked, and effectively works like an `out` // parameter. // // Other targets have very different models for how varying parameters // are passed: // // * GLSL/SPIR-V declare user-defined varying input/output as global variables, // and system-defined varying parameters are available as magic built-in variables. // // * CUDA compute kernels expose varying inputs as magic built-in // variables like `threadIdx`. // // * Our CPU compilation path requires the caller to pass in a `ComputeThreadVaryingInput` // that specifies the values of the critical varying parameters for compute shaders. // // While these targets differ in how they prefer to represent varying parameters, // they share the common theme that they cannot work with the varying parameter // signature of functions as written in vanilla HLSL. // // This pass in this file is responsible for walking the parameters (and result) // of each entry point in an IR module and transforming them into a form that // is legal for each target. The shared logic deals with many aspects of the // HLSL/Slang model for varying parameters that need to be "desugared" for these // targets: // // * Slang allows either an `out` parameter or the result (return value) of the // entry point to be used interchangeably, so ensuring both cases are treated // the same is handled here. 
//
// * Slang allows a varying parameter to use a `struct` or array type, so that
//   we need to recursively process elements and/or fields to find the leaf
//   varying parameters as they will be understood by other targets.
//
// * As an extension of the above, `struct`-type varying parameters in Slang
//   may mix user-defined and system-defined inputs/outputs.
//
// * Slang allows for `inout` varying parameters, which need to desugar into
//   distinct `in` and `out` parameters for targets like GLSL.

// X-macro listing every system-value semantic this pass recognizes.
//
// Each entry is `M(EnumeratorName, SemanticSpelling)`: the first argument
// becomes an enumerator of `SystemValueSemanticName`, and the second is the
// HLSL semantic spelling that the enumerator corresponds to.
//
#define SYSTEM_VALUE_SEMANTIC_NAMES(M)              \
    M(DispatchThreadID,     SV_DispatchThreadID)    \
    M(GroupID,              SV_GroupID)             \
    M(GroupThreadID,        SV_GroupThreadID)       \
    M(GroupThreadIndex,     SV_GroupIndex)          \
    /* end */

    /// A known system-value semantic name that can be applied to a parameter
    ///
enum class SystemValueSemanticName
{
    None = 0,

    // TODO: Should this enumeration be responsible for differentiating
    // cases where the same semantic name string is allowed in multiple stages,
    // or as both input/output in a single stage, and those different uses
    // might result in different meanings? The alternative is to always
    // pass around the semantic name, stage, and direction together so
    // that code can tell those special cases apart.

    // One enumerator per entry of `SYSTEM_VALUE_SEMANTIC_NAMES`, in order.
#define CASE(ID, NAME) ID,
    SYSTEM_VALUE_SEMANTIC_NAMES(CASE)
#undef CASE

    // TODO: There are many more system-value semantic names that we
    // can/should handle here, but for now I've restricted this list
    // to those that are necessary for translating compute shaders.
};

    /// A placeholder that represents the value of a legalized varying
    /// parameter, for the purposes of substituting it into IR code.
/// struct LegalizedVaryingVal { public: enum class Flavor { None, ///< No value (conceptually a literal of type `void`) Value, ///< A simple value represented as a single `IRInst*` Address, ///< A location in memory, identified by an address in an `IRInst*` }; LegalizedVaryingVal() {} static LegalizedVaryingVal makeValue(IRInst* irInst) { return LegalizedVaryingVal(Flavor::Value, irInst); } static LegalizedVaryingVal makeAddress(IRInst* irInst) { return LegalizedVaryingVal(Flavor::Address, irInst); } Flavor getFlavor() const { return m_flavor; } IRInst* getValue() const { SLANG_ASSERT(getFlavor() == Flavor::Value); return m_irInst; } IRInst* getAddress() const { SLANG_ASSERT(getFlavor() == Flavor::Address); return m_irInst; } private: LegalizedVaryingVal(Flavor flavor, IRInst* irInst) : m_flavor(flavor) , m_irInst(irInst) {} Flavor m_flavor = Flavor::None; IRInst* m_irInst = nullptr; }; /// Materialize the value of `val` as a single IR instruction. /// /// Any IR code that is needed to materialize the value will be emitted to `builder`. 
IRInst* materialize(IRBuilder& builder, LegalizedVaryingVal const& val) { switch( val.getFlavor() ) { case LegalizedVaryingVal::Flavor::None: return nullptr; // TODO: should use a `void` literal case LegalizedVaryingVal::Flavor::Value: return val.getValue(); case LegalizedVaryingVal::Flavor::Address: return builder.emitLoad(val.getAddress()); default: SLANG_UNEXPECTED("unimplemented"); break; } } void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, LegalizedVaryingVal const& src) { switch( dest.getFlavor() ) { case LegalizedVaryingVal::Flavor::None: break; case LegalizedVaryingVal::Flavor::Address: builder.emitStore(dest.getAddress(), materialize(builder, src)); break; default: SLANG_UNEXPECTED("unimplemented"); break; } } void assign(IRBuilder& builder, LegalizedVaryingVal const& dest, IRInst* src) { assign(builder, dest, LegalizedVaryingVal::makeValue(src)); } /// Context for the IR pass that legalizing entry-point /// varying parameters for a target. /// /// This is an abstract base type that needs to be inherited /// to implement the appropriate policy for a particular /// compilation target. /// struct EntryPointVaryingParamLegalizeContext { // This pass will be invoked on an entire module, and will // process all entry points in that module. // public: void processModule(IRModule* module, DiagnosticSink* sink) { m_module = module; m_sink = sink; // We will use multiple IR builders during the legalization // process, to avoid having state changes on one builder // affect other builders that might be in use. // // Once the basic initialization is done, we will allow // the subtype to implement its own initialization logic // that should occur at the start of processing a module. // beginModuleImpl(); // We now search for entry-point definitions in the IR module. // All entry points should appear at the global scope. // for(auto inst : module->getGlobalInsts()) { // Entry points are IR functions. 
// auto func = as(inst); if(!func) continue; // Entry point functions must have the `[entryPoint]` decoration. // auto entryPointDecor = func->findDecoration(); if(!entryPointDecor) continue; // Once we find an entry point we process it immediately. // processEntryPoint(func, entryPointDecor); } } protected: // As discussed in `processModule()`, a subtype can overide // the `beginModuleImpl()` method to perform work that should // only happen once per module that is processed. // virtual void beginModuleImpl() {} // We have both per-module and per-entry-point state that // needs to be managed. The former is set up in `processModule()`, // while the latter is used during `processEntryPoint`. // // Note: It would be possible in principle to remove some // the statefullness from this pass by factoring the // per-module and per-entry-point logic into distinct types, // but then every target-specific implementation would // need to comprise two types with complicated interdependencies. // The current solution of a single type with statefullness // seems easier to manage. IRModule* m_module = nullptr; DiagnosticSink* m_sink = nullptr; IRFunc* m_entryPointFunc = nullptr; IRBlock* m_firstBlock = nullptr; IRInst* m_firstOrdinaryInst = nullptr; Stage m_stage = Stage::Unknown; void processEntryPoint(IRFunc* entryPointFunc, IREntryPointDecoration* entryPointDecor) { m_entryPointFunc = entryPointFunc; // Before diving into the work of processing an entry point, we start by // extracting a bunch of information about the entry point that will // be useful to the downstream logic. // m_stage = entryPointDecor->getProfile().getStage(); m_firstBlock = entryPointFunc->getFirstBlock(); m_firstOrdinaryInst = m_firstBlock ? 
m_firstBlock->getFirstOrdinaryInst() : nullptr; auto entryPointLayoutDecoration = entryPointFunc->findDecoration(); SLANG_ASSERT(entryPointLayoutDecoration); auto entryPointLayout = as(entryPointLayoutDecoration->getLayout()); SLANG_ASSERT(entryPointLayout); // Note: Of particular importance is that we extract the first/last parameters // of the function *before* we allow the subtype to perform per-entry-point // setup operations. This ensures that if the subtype adds new parameters to // the beginnign or end of the parameter list, those new parameters won't // be processed. // IRParam* firstOriginalParam = m_firstBlock ? m_firstBlock->getFirstParam() : nullptr; IRParam* lastOriginalParam = m_firstBlock ? m_firstBlock->getLastParam() : nullptr; // We allow the subtype to perform whatever setup or code generation // it wants to on a per-entry-point basis. In some cases this might // inject code into the start of the function to provide the value // of certain system-value parameters. // beginEntryPointImpl(); // We now proceed to the meat of the work. // // We start by considering the result of the entry point function // if it is non-`void`. // auto resultType = entryPointFunc->getResultType(); if( !as(resultType) ) { // We need to translate the existing function result type // into zero or more varying parameters that are legal for // the target. An entry point function result should be // processed in a way that semantically matches an `out` parameter. // auto legalResult = createLegalVaryingVal( resultType, entryPointLayout->getResultLayout(), LayoutResourceKind::VaryingOutput); // Now that we have a representation of the value(s) that will // be used to hold the entry-point result we need to transform // any `returnVal(r)` instructions in the function body to // instead assign `r` to `legalResult` and then `returnVoid`. 
// IRBuilder builder(m_module); for( auto block : entryPointFunc->getBlocks() ) { auto returnValInst = as(block->getTerminator()); if(!returnValInst) continue; // We have a `returnVal` instruction that returns `resultVal`. // auto resultVal = returnValInst->getVal(); // To replace the existing `returnVal` instruction we will // emit an assignment to the new legalized result (whether // a global variable, `out` parameter, etc.) and a `returnVoid`. // builder.setInsertBefore(returnValInst); assign(builder, legalResult, resultVal); builder.emitReturn(); returnValInst->removeAndDeallocate(); } } // The parameters of the entry-point function will be processed in // order to legalize them. We need to be careful when iterating // over the parameters for a few reasons: // // * The subtype-specific setup logic could have introduce parameters // at the beginning or end of the list. We defend against that by // capturing `firstOriginalParam` and `lastOriginalParam` at the // start of this function, and only iterating over that range. // // * Somehow we might have an entry point declaration but not a definition // this is unlikely but defended against because `firstOriginalParam` // and `lastOriginalParam` will be null in that case. // // * We will often be removing the parameters once we have legalized // them, so we will modify the list while traversing it. We defend // against this by capturing `nextParam` at the start of each iteration // so that we move to the same parameter next, even if the current // parameter got removed. // // * The subtype-specific logic for legalizing a specific parameter // might decide to insert new parameters to replace it. This is another // case of modifying the parameter list while iterating it, and we // defend against it with `nextParam` just like we do for the problem // of deletion. 
// IRParam* nextParam = nullptr; for( auto param = firstOriginalParam; param; param = nextParam ) { nextParam = param->getNextParam(); processParam(param); if(param == lastOriginalParam) break; } } virtual void beginEntryPointImpl() {} // The next level down is the per-parameter processing logic, which // like the per-module and per-entry-point levels maintains its own // state to simplify the code (avoiding lots of long parameters lists). IRParam* m_param = nullptr; IRVarLayout* m_paramLayout = nullptr; void processParam(IRParam* param) { m_param = param; // We expect and require all entry-point parameters to have layout // information assocaited with them at this point. // auto paramLayoutDecoration = param->findDecoration(); SLANG_ASSERT(paramLayoutDecoration); m_paramLayout = as(paramLayoutDecoration->getLayout()); SLANG_ASSERT(m_paramLayout); if(!isVaryingParameter(m_paramLayout)) return; // TODO: The GLSL-specific variant of this pass has several // special cases that handle entry-point parameters for things like // GS output streams and input primitive topology. // TODO: The GLSL-specific variant of this pass has special cases // to deal with user-defined varying input to RT shaders, since // these don't translate to globals in the same way as all other // GLSL varying inputs. // We need to start by detecting whether the parameter represents // an `in` or an `out`/`inout` parameter, since that will determine // the strategy we take. 
// auto paramType = param->getDataType(); if(auto inOutType = as(paramType)) { processInOutParam(param, inOutType); } else if(auto outType = as(paramType)) { processOutParam(param, outType); } else { processInParam(param, paramType); } } // We anticipate that some targets may need to customize the handling // of `out` and `inout` varying parameters, so we have `virtual` methods // to handle those cases, which just delegate to a default implementation // that provides baseline behavior that should in theory work for // multiple targets. // virtual void processInOutParam(IRParam* param, IRInOutType* inOutType) { processMutableParam(param, inOutType); } virtual void processOutParam(IRParam* param, IROutType* inOutType) { processMutableParam(param, inOutType); } void processMutableParam(IRParam* param, IROutTypeBase* paramPtrType) { // The deafult handling of any mutable (`out` or `inout`) parameter // will be to introduce a local variable of the corresponding // type and to use that in place of the actual parameter during // exeuction of the function. // The replacement variable will have the type of the original // parameter (the `T` in `Out` or `InOut`). // auto valueType = paramPtrType->getValueType(); // The replacement variable will be declared at the top of // the function. // IRBuilder builder(m_module); builder.setInsertBefore(m_firstOrdinaryInst); auto localVar = builder.emitVar(valueType); auto localVal = LegalizedVaryingVal::makeAddress(localVar); if( const auto inOutType = as(paramPtrType) ) { // If the parameter was an `inout` and not just an `out` // parameter, we will create one more more legal `in` // parameters to represent the incoming value, // and then assign from those legalized input(s) // into our local variable at the start of the function. 
// auto inputVal = createLegalVaryingVal( valueType, m_paramLayout, LayoutResourceKind::VaryingInput); assign(builder, localVal, inputVal); } // Because the `out` or `inout` parameter is represented // as a pointer, and our local variabel is also a pointer // we can directly replace all uses of the original parameter // with uses of the variable. // param->replaceUsesWith(localVar); // For both `out` and `inout` parameters, we need to // introduce one or more legalized `out` parameters // to represent the outgoing value. // auto outputVal = createLegalVaryingVal( valueType, m_paramLayout, LayoutResourceKind::VaryingOutput); // In order to have changes to our local variable become // visible in the legalized outputs, we need to assign // from the local variable to the output as the last // operation before any `return` instructions. // for( auto block : m_entryPointFunc->getBlocks() ) { auto returnInst = as(block->getTerminator()); if(!returnInst) continue; builder.setInsertBefore(returnInst); assign(builder, outputVal, localVal); } // Once we are done replacing the original parameter, // we can remove it from the function. // param->removeAndDeallocate(); } void processInParam(IRParam* param, IRType* paramType) { // Legalizing an `in` parameter is easier than a mutable parameter. // We start by creating one or more legalized `in` parameters // to represent the incoming value. // auto legalVal = createLegalVaryingVal( paramType, m_paramLayout, LayoutResourceKind::VaryingInput); // Next, we "materialize" the legalized value to produce // an `IRInst*` that represents it. // // Note: We materialize each input parameter once, at the top // of the entry point. Making a copy in this way could // introduce overhead if an input parameter is an array, // since all indexing operations will now refer to a copy // of the original array. // // TODO: We could in theory iterate over all uses of // `param` and introduce a custom replacement for each. 
// Such a replacement strategy could produce better code // for things like indexing into varying arrays, but at the // cost of more accesses to the input parameter data. // IRBuilder builder(m_module); builder.setInsertBefore(m_firstOrdinaryInst); IRInst* materialized = materialize(builder, legalVal); // The materialized value can be used to completely // replace the original parameter. // param->replaceUsesWith(materialized); param->removeAndDeallocate(); } // Depending on the "direction" of the parameter (`in`, `out`, `inout`) // we may need to create one or legalized variables to represented it. // // We now turn our attention to the problem of creating a legalized // value (wrapping zero or more variables/parameters) to represent // a varying parameter of a given type for a specific direction: // either input or output, but not both. // LegalizedVaryingVal createLegalVaryingVal(IRType* type, IRVarLayout* varLayout, LayoutResourceKind kind) { // The process we are going to use for creating legalized // values is going to involve recursion over the `type` // of the parameter, and there is a lot of state that // we need to carry along the way. // // Rather than have our core recursive function have // many parameters that need to be followed through // all the recursive call sites, we are going to wrap // the relevant data up in a `struct` and pass all // the information down as a bundle. auto typeLayout = varLayout->getTypeLayout(); VaryingParamInfo info; info.type = type; info.varLayout = varLayout; info.typeLayout = typeLayout; info.kind = kind; return _createLegalVaryingVal(info); } // While recursing through the type of a varying parameter, // we may need to make a recursive call on the element type // of an array, while still tracking the fact that any // leaf parameter we encounter needs to have the "outer // array brackets" taken into account when giving it a type. 
// // For those purposes we have the `VaryingArrayDeclaratorInfo` // type that keeps track of outer layers of array-ness // for a parameter during our recursive walk. // // It is stored as a stack-allocated linked list, where the list flows // up through the call stack. // struct VaryingArrayDeclaratorInfo { IRInst* elementCount = nullptr; VaryingArrayDeclaratorInfo* next = nullptr; }; // Here is the declaration of the bundled information we care // about when declaring a varying parameter. // struct VaryingParamInfo { // We obviously care about the type of the parameter we // need to legalize, as well as the layout of that type. // IRType* type = nullptr; IRTypeLayout* typeLayout = nullptr; // We also care about the variable layout information for // the parameter, because that includes things like the semantic // name/index, as well as any binding information that was // computed (e.g., for the `location` of GLSL user-defined // varying parameters). // // Note: the `varLayout` member may not represent a layout for // a variable of the given `type`, because we might be peeling // away layers of array-ness. Consider: // // int stuff[3] : STUFF // // When processing the parameter `stuff`, we start with `type` // being `int[3]`, but then we will recurse on `int`. At that // point the `varLayout` will still refer to `stuff` with its // semantic of `STUFF`, but the `type` and `typeLayout` will // refer to the `int` type. // IRVarLayout* varLayout = nullptr; // As discussed above, sometimes `varLayout` will refer to an // outer declaration of array type, while `type` and `typeLayout` // refer to an element type (perhaps nested). // // The `arrayDeclarators` field stores a linked list representing // outer layers of "array brackets" that surround the variable/field // of `type`. // // If code decides to construct a leaf parameter based on `type`, // then it will need to use these `arrayDeclarators` to wrap the // type up to make it correct. 
// VaryingArrayDeclaratorInfo* arrayDeclarators = nullptr; // In some cases the decision-making about how to lower a parameter // will depend on the kind of varying parameter (input or output). // // TODO: We may find that there are cases where a target wants to // support true `inout` varying parameters, and `LayoutResourceKind` // cannot currently handle those. // LayoutResourceKind kind = LayoutResourceKind::None; // When we arrive at a leaf parameter/field, we can identify whether // it is a user-defined or system-value varying based on its semantic name. // // For convenience, target-specific subtypes only need to understand // the enumerated `systemValueSemanticName` rather than needing to // implement their own parsing of semantic name strings. // SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None; }; LegalizedVaryingVal _createLegalVaryingVal(VaryingParamInfo const& info) { // By default, when we seek to creating a legalized value // for a varying parameter, we will look at its type to // decide what to do. // // For most basic types, we will immediately delegate to the // base case (which will use target-specific logic). // // Note: The logic here will always fully scalarize the input // type, gernerated multiple SOA declarations if the input // was AOS. That choice is required for some cases in GLSL, // and seems to be a reasonable default policy, but it could // lead to some performance issues for shaders that rely // on varying arrays. // // TODO: Consider whether some carefully designed early-out // checks could avoid full scalarization when it is possible // to avoid. Those early-out cases would probably need to // align with the layout logic that is assigning `location`s // to varying parameters. 
// auto type = info.type; if (as(type)) { return createSimpleLegalVaryingVal(info); } else if( as(type) ) { return createSimpleLegalVaryingVal(info); } else if( as(type) ) { return createSimpleLegalVaryingVal(info); } else if( as(type) ) { // Note: For now we are handling matrix types in a varying // parameter list as if they were ordinary types like // scalars and vectors. This works well enough for simple // stuff, and is unlikely to see much use anyway. // // TODO: A more correct implementation will probably treat // a matrix-type varying parameter as if it was syntax // sugar for an array of rows. // return createSimpleLegalVaryingVal(info); } else if( auto arrayType = as(type) ) { // A varying parameter of array type is an interesting beast, // because depending on the element type of the array we // might end up needing to generate multiple parameters in // struct-of-arrays (SOA) fashion. This will notably // come up in the case where the element type is a `struct`, // with fields that mix both user-defined and system-value // semantics. // auto elementType = arrayType->getElementType(); auto elementCount = arrayType->getElementCount(); auto arrayLayout = as(info.typeLayout); SLANG_ASSERT(arrayLayout); auto elementTypeLayout = arrayLayout->getElementTypeLayout(); // We are going to recursively apply legalization to the // element type of the array, but when doing so we will // pass down information about the outer "array brackets" // that this type represented. 
// VaryingArrayDeclaratorInfo arrayDeclarator; arrayDeclarator.elementCount = elementCount; arrayDeclarator.next = info.arrayDeclarators; VaryingParamInfo elementInfo = info; elementInfo.type = elementType; elementInfo.typeLayout = elementTypeLayout; elementInfo.arrayDeclarators = &arrayDeclarator; return _createLegalVaryingVal(elementInfo); } else if( auto streamType = as(type)) { // Handling a geometry shader stream output type like // `TriangleStream` is similar to handling an array, // but we do *not* pass down a "declarator" to note // the wrapping type. // // This choice is appropriate for GLSL because geometry // shader outputs are just declared as their per-vertex // types and not wrapped in array or stream types. // // TODO: If we ever need to legalize geometry shaders for // a target with different rules we might need to revisit // this choice. // auto elementType = streamType->getElementType(); auto streamLayout = as(info.typeLayout); SLANG_ASSERT(streamLayout); auto elementTypeLayout = streamLayout->getElementTypeLayout(); VaryingParamInfo elementInfo = info; elementInfo.type = elementType; elementInfo.typeLayout = elementTypeLayout; return _createLegalVaryingVal(elementInfo); } // Note: This file is currently missing the case for handling a varying `struct`. // The relevant logic is present in `slang-ir-glsl-legalize`, but it would add // a lot of complexity to this file to include it now. // // The main consequence of this choice is that this pass doesn't support varying // parameters wrapped in `struct`s for the targets that require this pass // (currently CPU and CUDA). // // TODO: Copy over the relevant logic from the GLSL-specific pass, as part of // readying this file to handle the needs of all targets. // else { // When no special case matches, we assume the parameter // has a simple type that we can handle directly. 
// return createSimpleLegalVaryingVal(info); } } LegalizedVaryingVal createSimpleLegalVaryingVal(VaryingParamInfo const& info) { // At this point we've bottomed out in the type-based recursion // and we have a leaf parameter of some simple type that should // also have a single semantic name/index to work with. // TODO: This seems like the right place to "wrap" the type back // up in layers of array-ness based on the outer array brackets // that were accumulated. // Our first order of business will be to check whether the // parameter represents a system-value parameter. // auto varLayout = info.varLayout; auto semanticInst = varLayout->findSystemValueSemanticAttr(); if( semanticInst ) { // We will compare the semantic name against our list of // system-value semantics using conversion to lower-case // to achieve a case-insensitive comparison (this is // necessary because semantics in HLSL/Slang do not // treat case as significant). // // TODO: It would be nice to have a case-insensitive // comparsion operation on `UnownedStringSlice` to // avoid all the `String`s we crete and thren throw // away here. // String semanticNameSpelling = semanticInst->getName(); auto semanticName = semanticNameSpelling.toLower(); SystemValueSemanticName systemValueSemanticName = SystemValueSemanticName::None; #define CASE(ID, NAME) \ if(semanticName == String(#NAME).toLower()) \ { \ systemValueSemanticName = SystemValueSemanticName::ID; \ } \ else SYSTEM_VALUE_SEMANTIC_NAMES(CASE) #undef CASE { // no match } if( systemValueSemanticName != SystemValueSemanticName::None ) { // If the leaf parameter has a system-value semantic, then // we need to translate the system value in whatever way // is appropraite for the target. 
// // TODO: The logic here is missing the behavior from the // GLSL-specific pass that handles type conversion when // a user-declared system-value parameter might not // match the type that was expected exactly (e.g., they // declare a `uint2` but the parameter is a `uint3`). // VaryingParamInfo systemValueParamInfo = info; systemValueParamInfo.systemValueSemanticName = systemValueSemanticName; return createLegalSystemVaryingValImpl(systemValueParamInfo); } // TODO: We should seemingly do something if the semantic name // implies a system-value semantic (starts with `SV_`) but we // didn't find a match. // // In practice, this is probably something that should be handled // at the layout level (`slang-parameter-binding.cpp`), and the // layout for a parameter should include the `SystemValueSemanticName` // as an enumerated value rather than a string (so that downstream // code doesn't have to get into the business of parsing it). } // If there was semantic applied to the parameter *or* the semantic // wasn't recognized as a system-value semantic, then we need // to do whatever target-specific logic is required to legalize // a user-defined varying parameter. // return createLegalUserVaryingValImpl(info); } // The base type will provide default implementations of the logic // for creating user-defined and system-value varyings, but in // each case the default logic will simply diagnose an error. // // For targets that support either case, it is essential to // override these methods with appropriate logic. virtual LegalizedVaryingVal createLegalUserVaryingValImpl(VaryingParamInfo const& info) { return diagnoseUnsupportedUserVal(info); } virtual LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) { return diagnoseUnsupportedSystemVal(info); } // As a utility for target-specific subtypes, we define a routine // to diagnose the case of a system-value semantic that isn't // understood by the target. 
LegalizedVaryingVal diagnoseUnsupportedSystemVal(VaryingParamInfo const& info) { SLANG_UNUSED(info); m_sink->diagnose(m_param, Diagnostics::unimplemented, "this target doesn't support this system-defined varying parameter"); return LegalizedVaryingVal(); } LegalizedVaryingVal diagnoseUnsupportedUserVal(VaryingParamInfo const& info) { SLANG_UNUSED(info); m_sink->diagnose(m_param, Diagnostics::unimplemented, "this target doesn't support this user-defined varying parameter"); return LegalizedVaryingVal(); } // There are some cases of system-value inputs that can be derived // from other inputs; notably compute shaders support `SV_DispatchThreadID` // and `SV_GroupIndex` which can both be derived from the more primitive // `SV_GroupID` and `SV_GroupThreadID`, together with the extents // of the thread group (which are specified with `[numthreads(...)]`). // // As a utilty to target-specific subtypes, we define helpers for // calculating the value of these derived system values from the // more primitive ones. /// Emit code to calculate `SV_DispatchThreadID` IRInst* emitCalcDispatchThreadID( IRBuilder& builder, IRType* type, IRInst* groupID, IRInst* groupThreadID, IRInst* groupExtents) { // The dispatch thread ID can be computed as: // // dispatchThreadID = groupID*groupExtents + groupThreadID // // where `groupExtents` is the X,Y,Z extents of // each thread group in threads (as given by // `[numthreads(X,Y,Z)]`). 
return builder.emitAdd(type, builder.emitMul(type, groupID, groupExtents), groupThreadID); } /// Emit code to calculate `SV_GroupIndex` IRInst* emitCalcGroupThreadIndex( IRBuilder& builder, IRInst* groupThreadID, IRInst* groupExtents) { auto intType = builder.getIntType(); auto uintType = builder.getBasicType(BaseType::UInt); // The group thread index can be computed as: // // groupThreadIndex = groupThreadID.x // + groupThreadID.y*groupExtents.x // + groupThreadID.z*groupExtents.x*groupExtents.z; // // or equivalently (with one less multiply): // // groupThreadIndex = (groupThreadID.z * groupExtents.y // + groupThreadID.y) * groupExtents.x // + groupThreadID.x; // // `offset = groupThreadID.z` auto zAxis = builder.getIntValue(intType, 2); IRInst* offset = builder.emitElementExtract(uintType, groupThreadID, zAxis); // `offset *= groupExtents.y` // `offset += groupExtents.y` auto yAxis = builder.getIntValue(intType, 1); offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, yAxis)); offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, yAxis)); // `offset *= groupExtents.x` // `offset += groupExtents.x` auto xAxis = builder.getIntValue(intType, 0); offset = builder.emitMul(uintType, offset, builder.emitElementExtract(uintType, groupExtents, xAxis)); offset = builder.emitAdd(uintType, offset, builder.emitElementExtract(uintType, groupThreadID, xAxis)); return offset; } // Several of the derived calcluations rely on having // access to the "group extents" of a compute shader. // That information is expected to be present on // the entry point as a `[numthreads(...)]` attribute, // and we define a convenience routine for accessing // that information. 
IRInst* emitCalcGroupExtents( IRBuilder& builder, IRVectorType* type) { if(auto numThreadsDecor = m_entryPointFunc->findDecoration()) { static const int kAxisCount = 3; IRInst* groupExtentAlongAxis[kAxisCount] = {}; for( int axis = 0; axis < kAxisCount; axis++ ) { auto litValue = as(numThreadsDecor->getExtentAlongAxis(axis)); if(!litValue) return nullptr; groupExtentAlongAxis[axis] = builder.getIntValue(type->getElementType(), litValue->getValue()); } return builder.emitMakeVector(type, kAxisCount, groupExtentAlongAxis); } // TODO: We may want to implement a backup option here, // in case we ever want to support compute shaders with // dynamic/flexible group size on targets that allow it. // SLANG_UNEXPECTED("Expected '[numthreads(...)]' attribute on compute entry point."); UNREACHABLE_RETURN(nullptr); } }; // With the target-independent core of the pass out of the way, we can // turn our attention to the target-specific subtypes that handle // translation of "leaf" varying parameters. struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalizeContext { // CUDA compute kernels don't support user-defined varying // input or output, and there are only a few system-value // varying inputs to deal with. // // CUDA provides built-in global parameters `threadIdx`, // `blockIdx`, and `blockDim` that we can make use of. // IRGlobalParam* threadIdxGlobalParam = nullptr; IRGlobalParam* blockIdxGlobalParam = nullptr; IRGlobalParam* blockDimGlobalParam = nullptr; // All of our system values will be exposed with the // `uint3` type, and we'll cache a pointer to that // type to void looking it up repeatedly. // IRType* uint3Type = nullptr; // Scans through and returns the first typeLayout attribute of non-zero size. 
static LayoutResourceKind getLayoutResourceKind(IRTypeLayout* typeLayout) { for (auto attr : typeLayout->getSizeAttrs()) { if (attr->getSize() != 0) return attr->getResourceKind(); } return LayoutResourceKind::None; } IRInst* emitOptiXAttributeFetch(int& ioBaseAttributeIndex, IRType* typeToFetch, IRBuilder* builder) { if (auto structType = as(typeToFetch)) { List fieldVals; for (auto field : structType->getFields()) { auto fieldType = field->getFieldType(); auto fieldVal = emitOptiXAttributeFetch(ioBaseAttributeIndex, fieldType, builder); if (!fieldVal) return nullptr; fieldVals.add(fieldVal); } return builder->emitMakeStruct(typeToFetch, fieldVals); } else if (auto arrayType = as(typeToFetch)) { auto elementCountInst = as(arrayType->getElementCount()); IRIntegerValue elementCount = elementCountInst->getValue(); auto elementType = arrayType->getElementType(); List elementVals; for (IRIntegerValue ii = 0; ii < elementCount; ++ii) { auto elementVal = emitOptiXAttributeFetch(ioBaseAttributeIndex, elementType, builder); if (!elementVal) return nullptr; elementVals.add(elementVal); } return builder->emitMakeArray(typeToFetch, elementVals.getCount(), elementVals.getBuffer()); } else if (auto matType = as(typeToFetch)) { auto rowCountInst = as(matType->getRowCount()); if (rowCountInst) { auto rowType = builder->getVectorType(matType->getElementType(), matType->getColumnCount()); IRType* elementType = rowType; IRIntegerValue elementCount = rowCountInst->getValue(); List elementVals; for (IRIntegerValue ii = 0; ii < elementCount; ++ii) { auto elementVal = emitOptiXAttributeFetch(ioBaseAttributeIndex, elementType, builder); if (!elementVal) return nullptr; elementVals.add(elementVal); } return builder->emitIntrinsicInst(typeToFetch, kIROp_MakeMatrix, elementVals.getCount(), elementVals.getBuffer()); } } else if (auto vecType = as(typeToFetch)) { auto elementCountInst = as(vecType->getElementCount()); IRIntegerValue elementCount = elementCountInst->getValue(); IRType* 
elementType = vecType->getElementType(); List elementVals; for (IRIntegerValue ii = 0; ii < elementCount; ++ii) { auto elementVal = emitOptiXAttributeFetch(ioBaseAttributeIndex, elementType, builder); if (!elementVal) return nullptr; elementVals.add(elementVal); } return builder->emitMakeVector(typeToFetch, elementVals.getCount(), elementVals.getBuffer()); } else if (const auto basicType = as(typeToFetch)) { IRIntegerValue idx = ioBaseAttributeIndex; auto idxInst = builder->getIntValue(builder->getIntType(), idx); ioBaseAttributeIndex++; IRInst* args[] = { typeToFetch, idxInst }; IRInst* getAttr = builder->emitIntrinsicInst(typeToFetch, kIROp_GetOptiXHitAttribute, 2, args); return getAttr; } return nullptr; } void beginModuleImpl() SLANG_OVERRIDE { // Because many of the varying parameters are defined // as magic globals in CUDA, we can introduce their // definitions once per module, instead of once per // entry point. // IRBuilder builder(m_module); builder.setInsertInto(m_module->getModuleInst()); // We begin by looking up the `uint` and `uint3` types. // auto uintType = builder.getBasicType(BaseType::UInt); uint3Type = builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 3)); // Next we create IR type and variable layouts that // we can use to mark the global parameters like // `threadIdx` as varying parameters instead of // uniform. // IRTypeLayout::Builder typeLayoutBuilder(&builder); typeLayoutBuilder.addResourceUsage(LayoutResourceKind::VaryingInput, 1); auto typeLayout = typeLayoutBuilder.build(); IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); auto varLayout = varLayoutBuilder.build(); // Finaly, we construct global parameters to represent // `threadIdx`, `blockIdx`, and `blockDim`. 
// // Each of these parameters is given a target-intrinsic // decoration that ensures that (1) it will not get a declaration // emitted in output code, and (2) it will be referenced // by exactly the desired name (with no attempt to generate // a unique name). threadIdxGlobalParam = builder.createGlobalParam(uint3Type); builder.addTargetIntrinsicDecoration(threadIdxGlobalParam, CapabilitySet::makeEmpty(), UnownedTerminatedStringSlice("threadIdx")); builder.addLayoutDecoration(threadIdxGlobalParam, varLayout); blockIdxGlobalParam = builder.createGlobalParam(uint3Type); builder.addTargetIntrinsicDecoration(blockIdxGlobalParam, CapabilitySet::makeEmpty(), UnownedTerminatedStringSlice("blockIdx")); builder.addLayoutDecoration(blockIdxGlobalParam, varLayout); blockDimGlobalParam = builder.createGlobalParam(uint3Type); builder.addTargetIntrinsicDecoration(blockDimGlobalParam, CapabilitySet::makeEmpty(), UnownedTerminatedStringSlice("blockDim")); builder.addLayoutDecoration(blockDimGlobalParam, varLayout); } // While CUDA provides many useful system values // as built-in globals, it does not provide the // equivalent of `SV_DispatchThreadID` or // `SV_GroupIndex` as a built-in. // // We will instead synthesize those values on // entry to each kernel. IRInst* groupThreadIndex = nullptr; IRInst* dispatchThreadID = nullptr; void beginEntryPointImpl() SLANG_OVERRIDE { IRBuilder builder(m_module); builder.setInsertBefore(m_firstOrdinaryInst); // Note that we can use the built-in `blockDim` // variable to determine the group extents, // instead of inspecting the `[numthreads(...)]` // attribute. // // This choice makes our output more idomatic // as CUDA code, but might also cost a small // amount of performance by not folding in // the known constant values from `numthreads`. // // TODO: Add logic to use the values from // `numthreads` if it is present, but to fall // back to `blockDim` if not? 
dispatchThreadID = emitCalcDispatchThreadID( builder, uint3Type, blockIdxGlobalParam, threadIdxGlobalParam, blockDimGlobalParam); groupThreadIndex = emitCalcGroupThreadIndex( builder, threadIdxGlobalParam, blockDimGlobalParam); // Note: we don't pay attention to whether the // kernel actually makes use of either of these // system values when we synthesize them. // // We can get away with this because we know // that subsequent DCE passes will eliminate // the computations if they aren't used. // // The main alternative would be to compute // these values lazily, when they are first // referenced. While that is possible, it // requires more (and more subtle) code in this pass. } LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE { // Because all of the relevant values are either // ambiently available in CUDA, or were computed // eagerly in the entry block to the kernel // function, we can easily return the right // value to use for a system-value parameter. 
switch( info.systemValueSemanticName ) { case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(blockIdxGlobalParam); case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(threadIdxGlobalParam); case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex); case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID); default: return diagnoseUnsupportedSystemVal(info); } } LegalizedVaryingVal createLegalUserVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE { auto layoutResourceKind = getLayoutResourceKind(info.typeLayout); switch (layoutResourceKind) { case LayoutResourceKind::RayPayload: { IRBuilder builder(m_module); builder.setInsertBefore(m_firstOrdinaryInst); IRPtrType* ptrType = builder.getPtrType(info.type); IRInst* getRayPayload = builder.emitIntrinsicInst(ptrType, kIROp_GetOptiXRayPayloadPtr, 0, nullptr); return LegalizedVaryingVal::makeAddress(getRayPayload); // Todo: compute how many registers are required for the current payload. // If more than 32, use the above logic. 
// Otherwise, either use the optix_get_payload or optix_set_payload // intrinsics depending on input/output /*if (info.kind == LayoutResourceKind::VaryingInput) { } else if (info.kind == LayoutResourceKind::VaryingOutput) { } else { return diagnoseUnsupportedUserVal(info); }*/ } case LayoutResourceKind::HitAttributes: { IRBuilder builder(m_module); builder.setInsertBefore(m_firstOrdinaryInst); int ioBaseAttributeIndex = 0; IRInst* getHitAttributes = emitOptiXAttributeFetch(/*ioBaseAttributeIndex*/ ioBaseAttributeIndex, /* type to fetch */info.type, /*the builder in use*/ &builder); if (ioBaseAttributeIndex > 8) { m_sink->diagnose(m_param, Diagnostics::unexpected, "the supplied hit attribute exceeds the maximum hit attribute structure size (32 bytes)"); return LegalizedVaryingVal(); } return LegalizedVaryingVal::makeValue(getHitAttributes); } default: return diagnoseUnsupportedUserVal(info); } } }; struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalizeContext { // Slang translates compute shaders for CPU such that they always have an // initial parameter that is a `ComputeThreadVaryingInput*`, and that // type provides the essential parameters (`SV_GroupID` and `SV_GroupThreadID` // as fields). // // Our legalization pass for CPU this begins with the per-module logic // to synthesize an IR definition of that type and its fields, so that // we can use it across entry points. 
IRType* uintType = nullptr;
    IRVectorType* uint3Type = nullptr;
    IRType* uint3PtrType = nullptr;

    // The synthesized struct type that stands in for
    // `ComputeThreadVaryingInput`, and the pointer type to it
    // that is used for the per-entry-point parameter.
    IRStructType* varyingInputStructType = nullptr;
    IRPtrType* varyingInputStructPtrType = nullptr;

    // Keys for the two fields of the struct that we access.
    IRStructKey* groupIDKey = nullptr;
    IRStructKey* groupThreadIDKey = nullptr;

    void beginModuleImpl() SLANG_OVERRIDE
    {
        IRBuilder builder(m_module);
        builder.setInsertInto(m_module->getModuleInst());

        // Look up the `uint`, `uint3`, and pointer-to-`uint3` types
        // once, so that they can be re-used for every entry point.
        uintType = builder.getBasicType(BaseType::UInt);
        uint3Type = builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 3));
        uint3PtrType = builder.getPtrType(uint3Type);

        // As we construct the `ComputeThreadVaryingInput` type and its fields,
        // we mark them all as target intrinsics, which means that their
        // declarations will *not* be reproduced in the output code, instead
        // coming from the "prelude" file that already defines this type.

        varyingInputStructType = builder.createStructType();
        varyingInputStructPtrType = builder.getPtrType(varyingInputStructType);

        builder.addTargetIntrinsicDecoration(varyingInputStructType, CapabilitySet::makeEmpty(), UnownedTerminatedStringSlice("ComputeThreadVaryingInput"));

        // Field `groupID` of type `uint3`.
        groupIDKey = builder.createStructKey();
        builder.addTargetIntrinsicDecoration(groupIDKey, CapabilitySet::makeEmpty(), UnownedTerminatedStringSlice("groupID"));
        builder.createStructField(varyingInputStructType, groupIDKey, uint3Type);

        // Field `groupThreadID` of type `uint3`.
        groupThreadIDKey = builder.createStructKey();
        builder.addTargetIntrinsicDecoration(groupThreadIDKey, CapabilitySet::makeEmpty(), UnownedTerminatedStringSlice("groupThreadID"));
        builder.createStructField(varyingInputStructType, groupThreadIDKey, uint3Type);
    }

    // While the declaration of the `ComputeThreadVaryingInput` type
    // can be shared across all entry points, each entry point must
    // declare its own parameter to receive the varying parameters.
// // We will extract the relevant fields from the `ComputeVaryingThreadInput` // at the start of kernel execution (rather than repeatedly load them // at each use site), and will also eagerly compute the derived // values for `SV_DispatchThreadID` and `SV_GroupIndex`. IRInst* groupID = nullptr; IRInst* groupThreadID = nullptr; IRInst* groupExtents = nullptr; IRInst* dispatchThreadID = nullptr; IRInst* groupThreadIndex = nullptr; void beginEntryPointImpl() SLANG_OVERRIDE { groupID = nullptr; groupThreadID = nullptr; dispatchThreadID = nullptr; IRBuilder builder(m_module); auto varyingInputParam = builder.createParam(varyingInputStructPtrType); varyingInputParam->insertBefore(m_firstBlock->getFirstChild()); builder.setInsertBefore(m_firstOrdinaryInst); groupID = builder.emitLoad( builder.emitFieldAddress(uint3PtrType, varyingInputParam, groupIDKey)); groupThreadID = builder.emitLoad( builder.emitFieldAddress(uint3PtrType, varyingInputParam, groupThreadIDKey)); // Note: we need to rely on the presence of the `[numthreads(...)]` attribute // to tell us the size of the compute thread group, which we will then use // when computing the dispatch thread ID and group thread index. // // TODO: If we ever wanted to support flexible thread-group sizes for our // CPU target, we'd need to change it so that the thread-group size can // be passed in as part of `ComputeVaryingThreadInput`. // groupExtents = emitCalcGroupExtents(builder, uint3Type); dispatchThreadID = emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents); groupThreadIndex = emitCalcGroupThreadIndex(builder, groupThreadID, groupExtents); } LegalizedVaryingVal createLegalSystemVaryingValImpl(VaryingParamInfo const& info) SLANG_OVERRIDE { // Because all of the relvant system values were synthesized // into the first block of the entry-point function, we can // just return them wherever they are referenced. 
// // Note that any values that were synthesized but then are // not referened will simply be eliminated as dead code // in later passes. switch( info.systemValueSemanticName ) { case SystemValueSemanticName::GroupID: return LegalizedVaryingVal::makeValue(groupID); case SystemValueSemanticName::GroupThreadID: return LegalizedVaryingVal::makeValue(groupThreadID); case SystemValueSemanticName::GroupThreadIndex: return LegalizedVaryingVal::makeValue(groupThreadIndex); case SystemValueSemanticName::DispatchThreadID: return LegalizedVaryingVal::makeValue(dispatchThreadID); default: return diagnoseUnsupportedSystemVal(info); } } }; void legalizeEntryPointVaryingParamsForCPU( IRModule* module, DiagnosticSink* sink) { CPUEntryPointVaryingParamLegalizeContext context; context.processModule(module, sink); } void legalizeEntryPointVaryingParamsForCUDA( IRModule* module, DiagnosticSink* sink) { CUDAEntryPointVaryingParamLegalizeContext context; context.processModule(module, sink); } }