https://github.com/shader-slang/slang
Tip revision: e24c5a6cb9c3347477b83abe084a09ae8f9fde0a authored by Tim Foley on 08 January 2021, 00:01:48 UTC
Fill in some missing bits of capability API (#1652)
Fill in some missing bits of capability API (#1652)
Tip revision: e24c5a6
slang-ir-specialize-function-call.cpp
// slang-ir-specialize-function-call.cpp
#include "slang-ir-specialize-function-call.h"
#include "slang-ir.h"
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
namespace Slang
{
struct FunctionParameterSpecializationContext
{
// This type implements a pass to specialize functions
// with specific types of parameters (identified by
// `condition`) to ensure that they are
// legal or optimized for a given target.
//
// We start with member variables to stand in for
// the parameters that were passed to the top-level
// `specializeFunctionCalls` function.
//
// Note: `compileRequest` and `targetRequest` are stored on the
// context but are not read by any of the methods of this type.
//
BackEndCompileRequest* compileRequest;
TargetRequest* targetRequest;
IRModule* module;
// The condition on which parameters to specialize.
// It is only ever queried through `doesParamNeedSpecialization`.
FunctionCallSpecializeCondition* condition;
// Our general approach will be to think in terms
// of specializing call sites, which amount to
// `IRCall` instructions. We will keep a work list
// of call sites in the program that may be worth
// considering for specialization.
//
// The list is used as a stack: `processModule` always takes
// the most recently added call site off the end.
//
List<IRCall*> workList;
// Because we may need to generate specialized functions
// and generate new calls to those functions, we'll
// need some IR building state to get our work done.
//
// `processModule` connects `builderStorage` to
// `sharedBuilderStorage` before any instructions are built.
//
SharedIRBuilder sharedBuilderStorage;
IRBuilder builderStorage;
// Accessor for the single IR builder used by this pass.
IRBuilder* getBuilder() { return &builderStorage; }
// With the basic state out of the way, let's walk
// through the overall flow of the pass.
//
void processModule()
{
// We will start by initializing our IR building state.
//
sharedBuilderStorage.module = module;
sharedBuilderStorage.session = module->getSession();
builderStorage.sharedBuilder = &sharedBuilderStorage;
// Next we will populate our initial work list by
// recursively finding every single call site in the module.
//
addCallsToWorkListRec(module->getModuleInst());
// We will process the work list until it goes dry,
// treating it like a stack of work items.
//
while( workList.getCount() )
{
auto call = workList.getLast();
workList.removeLast();
// At each call site we first check whether it
// is something we can (and should) specialize,
// and if so, do it. The process of specializing
// a function may introduce new call sites that
// become candidates for specialization, so
// our work list may grow along the way.
//
if( canSpecializeCall(call) )
{
specializeCall(call);
}
}
}
// Setting up the work list is a simple recursive procedure.
//
void addCallsToWorkListRec(IRInst* inst)
{
    // The instruction itself goes on the work list
    // when it is a call site.
    //
    IRCall* const callSite = as<IRCall>(inst);
    if( callSite )
        workList.add(callSite);

    // Whether or not `inst` was a call, its children
    // may contain further call sites, so visit them all.
    //
    for( auto nested : inst->getChildren() )
        addCallsToWorkListRec(nested);
}
// We need a way to decide for a given call site
// whether we can/must specialize it.
//
bool canSpecializeCall(IRCall* call)
{
    // Specialization requires a callee that is statically
    // known to be a function, and that has a body we can
    // clone. Anything else gives us nothing to specialize.
    //
    IRFunc* const callee = as<IRFunc>(call->getCallee());
    if( !callee || !callee->isDefinition() )
        return false;

    // Two questions remain:
    //
    // 1. *Should* we specialize? Only if at least one
    //    parameter of the callee needs specialization.
    //
    // 2. *Can* we specialize? Only if every argument that
    //    is passed for such a parameter is suitable for
    //    specialization.
    //
    // Both are answered by one walk over the parameters of
    // the callee (a linked list) in lock-step with the
    // arguments of the call (an indexed sequence), which is
    // why the argument index is tracked by hand.
    //
    bool sawSpecializableParam = false;
    UInt nextArgIndex = 0;
    for( auto param : callee->getParams() )
    {
        const UInt argIndex = nextArgIndex;
        ++nextArgIndex;

        SLANG_ASSERT(argIndex < call->getArgCount());
        auto arg = call->getArg(argIndex);

        if( doesParamNeedSpecialization(param) )
        {
            // Question (1) is settled as soon as we see
            // one parameter that needs specialization.
            //
            sawSpecializableParam = true;

            // Question (2) fails as soon as one of the
            // corresponding arguments is unsuitable.
            //
            if( !isArgSuitableForSpecialization(arg) )
                return false;
        }
    }

    // Getting here means no unsuitable argument was found,
    // so the answer is just whether there was anything
    // to specialize in the first place.
    //
    return sawSpecializableParam;
}
// Of course, now we need to back-fill the predicates that
// the above function used to evaluate parameters and arguments.
bool doesParamNeedSpecialization(IRParam* param)
{
    // The decision is delegated entirely to the
    // condition object that the pass was configured with.
    //
    const bool needsSpecialization = condition->doesParamNeedSpecialization(param);
    return needsSpecialization;
}
bool isArgSuitableForSpecialization(IRInst* inArg)
{
    // Suitability depends on the recursive structure of the
    // argument: we peel off wrapping operations until we
    // either reach a global shader parameter (suitable) or
    // something we don't know how to handle (unsuitable).
    //
    // The "recursion" only ever continues into one operand,
    // so a loop with a cursor is all we need.
    //
    IRInst* cursor = inArg;
    while( !as<IRGlobalParam>(cursor) )
    {
        switch( cursor->op )
        {
        case kIROp_getElement:
        case kIROp_Load:
            // An indexing operation (`base[index]`) or a load
            // is suitable exactly when its first operand is,
            // so step inward and test that operand next.
            //
            cursor = cursor->getOperand(0);
            break;

        default:
            // Anything else is not considered suitable.
            //
            // TODO: There may be other cases that are worth
            // handling here. The current code is based on
            // observation of what simple shaders do in
            // practice.
            //
            return false;
        }
    }

    // The loop only exits normally once the cursor has reached
    // a global shader parameter, which the specialized callee
    // can refer to directly.
    //
    return true;
}
// Once we've determined that a given call site can/should
// be specialized, we need to perform the actual specialization.
// This is where things are going to get more involved.
//
// There are a few different concerns we need to deal with
// that mean we end up having two different passes that walk
// over the parameters/arguments of the call (in addition to
// the ones we had above for determining if we can/should
// specialize in the first place).
//
// The first of the two passes determines information
// relevant to the call site, comprising both the arguments
// that will be passed to the specialized function as
// well as a "key" to identify the specialized function
// that is required.
//
// We will use the key type defined as part of the IR cloning
// infrastructure, which uses a sequence of `IRInst*`s
// to hold the state of the key:
//
using Key = IRSimpleSpecializationKey;
// As indicated above, the information we collect about a call
// site consists of the key for the specialized function we
// will call, and a list of the arguments that will be passed
// to the call.
//
struct CallSpecializationInfo
{
// Identifies the specialized callee that this call site needs:
// the original function followed by the global shader
// parameters that are being specialized into it.
Key key;
// The arguments to pass at the rewritten call site, in order.
List<IRInst*> newArgs;
};
// Once we've collected the information about a call site
// we can use a dictionary to see if we already created
// a specialized version of the callee that matches its
// requirements.
//
// Cache of the specialized functions generated so far.
// `specializeCall` looks a call site's key up here before
// generating a new function, and registers each function
// that it does generate.
Dictionary<Key, IRFunc*> specializedFuncs;
// If the dictionary didn't have a specialized function
// suitable for a call site, we need a second information-gathering
// pass to decide what the new parameters of the specialized
// functions should be, and what instructions the new function
// must execute in its body to set up the replacements for the
// old parameters.
//
struct FuncSpecializationInfo
{
// The parameters of the specialized function, in order.
List<IRParam*> newParams;
// Instructions to insert at the start of the specialized
// function's entry block, after its parameters.
List<IRInst*> newBodyInsts;
// One entry per parameter of the original function (in
// parameter order): the value that stands in for that
// parameter in the specialized function.
List<IRInst*> replacementsForOldParameters;
};
// Before diving into how the different passes collect
// their information, we will dive into the main
// specialization logic first.
//
void specializeCall(IRCall* oldCall)
{
    // The caller has already established (via
    // `canSpecializeCall`) that this call site can and
    // should be specialized, so the callee must be a
    // known function with a body.
    //
    IRFunc* const oldFunc = as<IRFunc>(oldCall->getCallee());
    SLANG_ASSERT(oldFunc);
    SLANG_ASSERT(oldFunc->isDefinition());

    // First gather what is specific to the call site: the
    // key that identifies the specialized callee we need,
    // and the argument list for the rewritten call.
    //
    CallSpecializationInfo callInfo;
    gatherCallInfo(oldCall, oldFunc, callInfo);

    // A previous call site may already have caused a
    // matching specialization to be generated, in which
    // case we just re-use it.
    //
    IRFunc* newFunc = nullptr;
    const bool foundExisting = specializedFuncs.TryGetValue(callInfo.key, newFunc);
    if( !foundExisting )
    {
        // Otherwise we gather the callee-side information
        // (deferred until now since it would be wasted work
        // on a cache hit), generate the specialized function
        // from it, and remember the result for later calls.
        //
        FuncSpecializationInfo funcInfo;
        gatherFuncInfo(oldCall, oldFunc, funcInfo);

        newFunc = generateSpecializedFunc(oldFunc, funcInfo);
        specializedFuncs.Add(callInfo.key, newFunc);
    }

    // Either way we now have a function to call. Build the
    // new call, put it where the old call was, redirect all
    // uses of the old call to it, and then delete the old call.
    //
    auto builder = getBuilder();
    auto newCall = builder->emitCallInst(
        oldCall->getFullType(),
        newFunc,
        callInfo.newArgs.getCount(),
        callInfo.newArgs.getBuffer());

    newCall->insertBefore(oldCall);
    oldCall->replaceUsesWith(newCall);
    oldCall->removeAndDeallocate();
}
// Before diving into the details on how we gather information
// and specialize callees, let's stop to think about what we'd
// like to do in terms of individual parameters and arguments.
//
// Suppose we are specializing both a call site C and the callee
// function F, and we are considering a particular pair of
// a parameter P of F, and an argument A at the call site.
//
// The full extent of information we might want to know given
// P and A is:
//
// * What arguments need to be added to the specialized call?
// * What parameters need to be added to the specialized callee?
// * What instructions are needed in the body of the specialized
// callee to synthesize the value that will stand in for P?
// * What information, if any, needs to be used to distinguish
// this specialized callee from others that might be generated for F?
//
// An easy case is when P is a parameter that doesn't need
// specialization. In that case:
//
// * The existing argument A should be used as an argument in
// the specialized call.
// * A clone P' of the existing parameter P should be used as a
// parameter of the specialized callee.
// * No additional instructions are needed in the body of
// the callee; the cloned parameter P' should stand in for P.
// * No information should be added to the specialization key
// based on P and A.
//
// The more interesting case is when P has a resource type, and
// A is some global shader parameter G.
//
// * No argument should be added at the new call site
// * No parameter should be added to the specialized callee
// * No additional instructions are needed in the body of
// the callee; the global G should stand in for P.
// * The global G should be used to distinguish this specialized
// callee from those that might be specialized for a different
// global shader parameter.
//
// As a final example, imagine that P is still a resource type,
// but A is now an indexing operation into an array: `G[idx]`:
//
// * An argument for `idx` should be added at the call site
// * A parameter `p_idx` with the same type as `idx` should be added
// to the specialized callee.
// * An instruction should be added to the specialized callee
// to compute `G[p_idx]` and use that to stand in for P.
// * The global G should still be used to distinguish this specialized
// call site from others.
//
// That's a lot of examples, I know, but hopefully it gives a
// sense of the information we are tracking and how it differs
// across the various cases. While the example only covered one
// level of indexing, the actual implementation will handle the
// case of arbitrarily many levels of indexing, which can mean
// piping through any number of additional integer parameters
// to the callee.
// The information we gather for a call site (before we know
// whether a specialized callee is needed) is just the new
// argument list, and the "key" information that distinguishes
// what specialized callee we want/need.
//
void gatherCallInfo(
    IRCall* oldCall,
    IRFunc* oldFunc,
    CallSpecializationInfo& callInfo)
{
    // Specializations of different functions must never be
    // confused with one another, so the original function
    // is always the first component of the key.
    //
    callInfo.key.vals.add(oldFunc);

    // Everything else comes from visiting each parameter of
    // the callee together with the argument passed for it.
    //
    UInt argIndex = 0;
    for( auto oldParam : oldFunc->getParams() )
    {
        getCallInfoForParam(callInfo, oldParam, oldCall->getArg(argIndex));
        ++argIndex;
    }
}
void getCallInfoForParam(
    CallSpecializationInfo& ioInfo,
    IRParam* oldParam,
    IRInst* oldArg)
{
    // A parameter that needs specialization is handled by
    // inspecting the structure of its argument. That logic
    // lives in its own function because it may recurse.
    //
    if( doesParamNeedSpecialization(oldParam) )
    {
        getCallInfoForArg(ioInfo, oldArg);
        return;
    }

    // Otherwise the argument is passed through unchanged at
    // the new call site, and contributes nothing to the key
    // that identifies the specialized callee.
    //
    ioInfo.newArgs.add(oldArg);
}
void getCallInfoForArg(
    CallSpecializationInfo& ioInfo,
    IRInst* oldArg)
{
    // Base case: the argument is a global shader parameter.
    //
    // Nothing is passed at the new call site (the callee will
    // refer to the global directly), but the global becomes
    // part of the key, so that specializations for different
    // globals are kept distinct.
    //
    if( auto oldGlobalParam = as<IRGlobalParam>(oldArg) )
    {
        ioInfo.key.vals.add(oldGlobalParam);
        return;
    }

    switch( oldArg->op )
    {
    case kIROp_getElement:
        // The argument has the form `base[index]`.
        //
        // We treat it as if `base` and `index` had been
        // passed separately: `base` is processed recursively
        // as the thing being specialized, and then `index`
        // is handled like any ordinary argument, which means
        // it is added to the new argument list and does not
        // affect the key.
        //
        getCallInfoForArg(ioInfo, oldArg->getOperand(0));
        ioInfo.newArgs.add(oldArg->getOperand(1));
        break;

    case kIROp_Load:
        // A load contributes nothing of its own; only the
        // operand being loaded from is processed.
        //
        getCallInfoForArg(ioInfo, oldArg->getOperand(0));
        break;

    default:
        // Reaching this point means that
        // `isArgSuitableForSpecialization` accepted a form
        // of argument that this function does not handle.
        //
        SLANG_UNEXPECTED("mising case in 'getCallInfoForArg'");
    }
}
// The remaining information we've discussed is only
// gathered once we decide we want to generate a
// specialized function, but it follows much the same flow.
//
void gatherFuncInfo(
    IRCall* oldCall,
    IRFunc* oldFunc,
    FuncSpecializationInfo& funcInfo)
{
    UInt argIndex = 0;
    for( auto oldParam : oldFunc->getParams() )
    {
        IRInst* const oldArg = oldCall->getArg(argIndex);
        ++argIndex;

        // The job for each parameter/argument pair is to come
        // up with the value that will take the place of the
        // parameter inside the specialized function.
        //
        IRInst* const replacement = getSpecializedValueForParam(funcInfo, oldParam, oldArg);

        // Replacements are recorded in parameter order, so
        // that they can later be matched back up with the
        // parameters of the original function.
        //
        funcInfo.replacementsForOldParameters.add(replacement);
    }
}
IRInst* getSpecializedValueForParam(
    FuncSpecializationInfo& ioInfo,
    IRParam* oldParam,
    IRInst* oldArg)
{
    // When the parameter needs specialization, its stand-in
    // is derived from the structure of the argument.
    //
    if( doesParamNeedSpecialization(oldParam) )
        return getSpecializedValueForArg(ioInfo, oldArg);

    // Otherwise the specialized callee simply gets a fresh
    // parameter of the same type, which plays the role that
    // the old parameter did and is used as its replacement.
    //
    auto newParam = getBuilder()->createParam(oldParam->getFullType());
    ioInfo.newParams.add(newParam);
    return newParam;
}
IRInst* getSpecializedValueForArg(
    FuncSpecializationInfo& ioInfo,
    IRInst* oldArg)
{
    // Computes the value that will stand in, inside the
    // specialized callee, for a parameter whose argument at
    // the call site was `oldArg`. Any parameters and body
    // instructions the callee needs for that are appended
    // to `ioInfo`.
    //
    // The logic here parallels `getCallInfoForArg`,
    // and only differs in what information it is gathering.
    //
    // As before, the base case is when we have a global
    // shader parameter.
    //
    if( auto globalParam = as<IRGlobalParam>(oldArg) )
    {
        // The specialized function will not need any
        // parameter in this case, and the global itself
        // should be used to stand in for the original
        // parameter in the specialized function.
        //
        return globalParam;
    }
    else if( oldArg->op == kIROp_getElement )
    {
        // This is the case where the argument is
        // in the form `oldBase[oldIndex]`.
        //
        auto oldBase = oldArg->getOperand(0);
        auto oldIndex = oldArg->getOperand(1);

        // In `getCallInfoForArg` this case was
        // handled by acting as if `oldBase` and
        // `oldIndex` were being passed as two
        // separate arguments.
        //
        // We'll follow the same structure here,
        // starting by recursively processing `oldBase`
        // to get a value that can stand in for it
        // in the specialized callee.
        //
        auto newBase = getSpecializedValueForArg(ioInfo, oldBase);

        // Next we'll process `oldIndex` as if it
        // was an ordinary argument (not a specialized one),
        // which means creating a parameter to receive its value,
        // which will also stand in for `oldIndex` in
        // the body of the specialized callee.
        //
        auto builder = getBuilder();
        auto newIndex = builder->createParam(oldIndex->getFullType());
        ioInfo.newParams.add(newIndex);

        // Finally, we need to compute a value that
        // can stand in for `oldArg` (which was
        // `oldBase[oldIndex]`) in the body of the
        // specialized callee.
        //
        // Because we have both a `newBase` and a
        // `newIndex` it is natural to construct
        // `newBase[newIndex]` and use that.
        //
        // The only complication is that we need
        // to make sure that our IR builder isn't
        // set to insert newly created instructions
        // anywhere, since the `emit*` functions
        // will try to automatically insert new
        // instructions if an insertion location
        // is set.
        //
        builder->setInsertInto(nullptr);
        auto newVal = builder->emitElementExtract(
            oldArg->getFullType(),
            newBase,
            newIndex);

        // Because our new instruction wasn't
        // actually inserted anywhere, we need to
        // add it to our gathered list of instructions
        // that should be inserted into the body of
        // the specialized callee.
        //
        ioInfo.newBodyInsts.add(newVal);

        return newVal;
    }
    else if( oldArg->op == kIROp_Load )
    {
        // This is the case where the argument is a load
        // from `oldBase`.
        //
        // No extra parameter is needed (matching
        // `getCallInfoForArg`, which adds no argument for
        // a load), but the value that stands in for the
        // original parameter must still be the *loaded*
        // value: the parameter had the type of `oldArg`,
        // not the type of the thing being loaded from.
        // Using the specialized base directly would
        // substitute a value of the wrong type.
        //
        auto oldBase = oldArg->getOperand(0);
        auto newBase = getSpecializedValueForArg(ioInfo, oldBase);

        // As in the indexing case, the load is created
        // without an insertion location and queued up to be
        // inserted into the body of the specialized callee.
        // The recursive step above has already queued any
        // instructions that `newBase` depends on, so the
        // load will be inserted after them.
        //
        auto builder = getBuilder();
        builder->setInsertInto(nullptr);
        auto newVal = builder->emitLoad(
            oldArg->getFullType(),
            newBase);
        ioInfo.newBodyInsts.add(newVal);

        return newVal;
    }
    else
    {
        // If we don't match one of the above cases,
        // then `isArgSuitableForSpecialization` is
        // letting through cases that this function
        // hasn't been updated to handle.
        //
        SLANG_UNEXPECTED("mising case in 'getSpecializedValueForArg'");
        UNREACHABLE_RETURN(nullptr);
    }
}
// With all of that data-gathering code out of the way,
// we are now prepared to walk through the process of
// specializing a given callee function based on
// the information we have gathered.
//
IRFunc* generateSpecializedFunc(
    IRFunc* oldFunc,
    FuncSpecializationInfo const& funcInfo)
{
    // The heavy lifting is done by the shared IR cloning
    // infrastructure, which is driven by an "environment"
    // that maps values in the old code to the values that
    // should replace them in the clone.
    //
    // We seed that environment by pairing up each parameter
    // of the old function with the replacement that was
    // picked for it during information gathering.
    //
    IRCloneEnv cloneEnv;
    UInt oldParamIndex = 0;
    for( auto oldParam : oldFunc->getParams() )
    {
        IRInst* replacement = funcInfo.replacementsForOldParameters[oldParamIndex];
        ++oldParamIndex;

        cloneEnv.mapOldValToNew.Add(oldParam, replacement);
    }

    // The type of the specialized function is built from
    // the types of its *new* parameters (which can look
    // quite different from the original parameter list),
    // together with the original result type.
    //
    List<IRType*> newParamTypes;
    for( auto newParam : funcInfo.newParams )
        newParamTypes.add(newParam->getFullType());

    auto builder = getBuilder();
    IRType* newFuncType = builder->getFuncType(
        newParamTypes.getCount(),
        newParamTypes.getBuffer(),
        oldFunc->getResultType());

    // Creating the empty function and giving it a type is
    // the "first phase" of cloning it (an `IRFunc` has no
    // operands to worry about).
    //
    IRFunc* newFunc = builder->createFunc();
    newFunc->setFullType(newFuncType);

    // The "second phase" recursively clones the decorations,
    // blocks, and instructions nested under the old function.
    //
    cloneInstDecorationsAndChildren(
        &cloneEnv,
        builder->sharedBuilder,
        oldFunc,
        newFunc);

    // What is still missing from `newFunc` is its parameters,
    // plus any body instructions that information gathering
    // decided were needed. These go in the entry block, just
    // ahead of its first ordinary instruction.
    //
    // Both the block and the instruction are expected to
    // exist: `oldFunc` was required to be a definition (so
    // there is at least one block), and in valid IR every
    // block has at least one ordinary instruction (its
    // terminator).
    //
    auto entryBlock = newFunc->getFirstBlock();
    SLANG_ASSERT(entryBlock);

    auto insertionPoint = entryBlock->getFirstOrdinaryInst();
    SLANG_ASSERT(insertionPoint);

    // Inserting each item before the same instruction leaves
    // them in the output in exactly the order we visit them:
    // all of the parameters, then all of the body instructions.
    //
    for( auto param : funcInfo.newParams )
        param->insertBefore(insertionPoint);

    for( auto bodyInst : funcInfo.newBodyInsts )
        bodyInst->insertBefore(insertionPoint);

    // The new function may contain call sites that did not
    // exist when the initial work list was built, so make
    // sure they get considered for specialization as well
    // before handing the function back.
    //
    addCallsToWorkListRec(newFunc);

    return newFunc;
}
};
// The top-level function for invoking the specialization pass
// is straightforward. We set up the context object
// and then defer to it for the real work.
//
void specializeFunctionCalls(
    BackEndCompileRequest* compileRequest,
    TargetRequest* targetRequest,
    IRModule* module,
    FunctionCallSpecializeCondition* condition)
{
    // All of the real work happens in the context type;
    // here we only hand it the inputs and start it running.
    //
    FunctionParameterSpecializationContext specializer;

    specializer.module = module;
    specializer.condition = condition;
    specializer.compileRequest = compileRequest;
    specializer.targetRequest = targetRequest;

    specializer.processModule();
}
} // namespace Slang