https://github.com/shader-slang/slang
Tip revision: adb1131d08f28f0bc5f729e88b73cf22846c86c5 authored by Tim Foley on 05 February 2021, 17:01:36 UTC
Initial implementation of interface conjunctions (#1691)
Initial implementation of interface conjunctions (#1691)
Tip revision: adb1131
slang-ir-union.cpp
// slang-ir-union.cpp
#include "slang-ir-union.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
namespace Slang {
// This file will implement a pass to replace any union types (currently
// just tagged unions) with plain `struct` types that attempt to provide
// equivalent semantics. This will necessarily be a bit fragile, and there
// will be fundamental limits to what the translation can support without
// improved features in the target shading languages/ILs.
struct DesugarUnionTypesContext
{
// We'll start with some basic state that we need to get the job done.
//
// This includes the IR module we are to process, as well as IR building
// state that we will initialize once and then use throughout the pass.
//
// `module` has no default value, so it must be assigned before
// `processModule()` is called; `processModule()` reads it to wire up
// `sharedBuilderStorage` and `builderStorage`.
//
IRModule* module;
SharedIRBuilder sharedBuilderStorage;
IRBuilder builderStorage;
// Returns a pointer to the one builder used by every step of the pass.
// The insertion point is whatever the last user of the builder set, so
// code that emits instructions should position the builder first.
IRBuilder* getBuilder() { return &builderStorage; }
// Because we will be replacing instructions that refer to unions with
// different logic, we'll want to remove the original instructions.
// However, we need to be careful about modifying the IR tree while also
// iterating it, and to keep things simple for ourselves we'll go ahead
// and build up a list of instructions to remove along the way, and then
// remove them all at the end.
//
// Entries are appended by `processInst` (after the uses of an instruction
// have been redirected to its replacement) and the whole list is drained
// at the end of `processModule`.
//
List<IRInst*> instsToRemove;
// The overall flow of the pass is pretty simple, so we will walk through it now.
//
void processModule()
{
// We start by initializing our IR building state.
//
sharedBuilderStorage.session = module->session;
sharedBuilderStorage.module = module;
builderStorage.sharedBuilder = &sharedBuilderStorage;
// Next, we will search for any instruction that create or use
// union types, and process them accordingingly (usually by
// constructing a new instruction to replace them).
//
processInstRec(module->getModuleInst());
// Along the way we will build up a list of the tagged union
// types that we encountered, but we will refrain from replacing
// them until we are done (so that we always know that the instructions
// we process above refer to the original type, and not its
// replacement.
//
for( auto info : taggedUnionInfos )
{
auto taggedUnionType = info->taggedUnionType;
auto replacementInst = info->replacementInst;
// TODO: We should consider transferring decorations from the source
// type to the destination, but doing so carelessly could create
// problems, since an IR struct type shouldn't have, e.g., a
// `TaggedUnionTypeLayout` attached to it.
taggedUnionType->replaceUsesWith(replacementInst);
taggedUnionType->removeAndDeallocate();
}
// As described previously, we build up the `instsToRemove` list as
// we iterate so that we can remove them all here and not risk
// modifying the IR tree while also walking it.
//
// TODO: This might be overkill and we could conceivably just be
// a bit careful in `processInstRec`.
//
for(auto inst : instsToRemove)
{
inst->removeAndDeallocate();
}
}
// In order to replace a (tagged) union type, we will need to know
// something about it, and we will use the `TaggedUnionInfo` type
// to collect all the relevant information.
//
struct TaggedUnionInfo : public RefObject
{
    // Everything the pass needs to know in order to replace one tagged
    // union type with a plain `struct`. All pointer members default to
    // null so that a partially-filled record never holds garbage.

    // We obviously need to know the tagged union itself, and
    // we will also use this structure to track the instruction
    // (an IR struct type) that will replace it.
    //
    IRTaggedUnionType* taggedUnionType = nullptr;
    IRInst* replacementInst = nullptr;

    // In order to compute a suitable layout for the replacement
    // `struct` type we need to know how the tagged union itself
    // would be laid out in memory, so we require that all tagged
    // unions in the generated IR have an associated (target-specific)
    // layout.
    //
    IRTaggedUnionTypeLayout* taggedUnionTypeLayout = nullptr;

    // The basic approach is to use 16-byte chunks (represented as an array
    // of `uint4`s) to represent the "bulk" of a type, and then use a single
    // field that could be up to 12 bytes to represent the "rest" of the type.
    //
    // Note that there are deeply ingrained assumptions here that all types
    // are at least four bytes in size (so that unions cannot easily
    // accommodate a `half` value), and that any types *larger* than four
    // bytes will need to be loaded/stored via multiple 4-byte loads/stores.
    //
    // With the basic idea out of the way, we need an IR level field
    // in our struct to hold the bulk data, which comprises a "key" for
    // looking up the field, and the type of the field itself. We also
    // keep track of how many bytes we put in our bulk storage.
    //
    // The bulk field might be:
    //
    // - null, if none of the case types was 16 bytes or more
    // - a single `uint4` for between 16 and 31 (inclusive) bytes
    // - an array of `uint4`s for 32 or more bytes
    //
    UInt64 bulkSize = 0;
    IRInst* bulkFieldKey = nullptr;
    IRType* bulkFieldType = nullptr;

    // The same basic idea then applies to the rest of the data.
    //
    // The "rest" field will either be absent (if the size of the
    // type was evenly divisible by 16), a scalar `uint`, or else
    // a 2- or 3-component vector of `uint`.
    //
    UInt64 restSize = 0;
    IRInst* restFieldKey = nullptr;
    IRType* restFieldType = nullptr;

    // Finally, since we are currently working with tagged unions,
    // we need a field to hold the tag, which will always be allocated
    // after the fields that hold the bulk/rest of the payload.
    //
    // This field is always a single `uint`.
    //
    // TODO: if/when we support untagged unions, they could be handled
    // by having this field be null.
    //
    IRInst* tagFieldKey = nullptr;
};
// We will build up a list of all the tagged union types we encounter,
// so that we can replace them with the synthesized types when we are done.
//
// This list holds the owning (`RefPtr`) references to the info records.
//
List<RefPtr<TaggedUnionInfo>> taggedUnionInfos;
// It is possible that we will see the same tagged union type referenced
// many times in the IR, but we only want to synthesize the information
// above (including the various IR structures) once, so we also maintain
// a map from the original IR type to the corresponding information.
//
// The map stores raw pointers to records that are also stored in
// `taggedUnionInfos`; both containers are filled together by
// `getTaggedUnionInfo`.
//
Dictionary<IRInst*, TaggedUnionInfo*> mapIRTypeToTaggedUnionInfo;
// We will process all instructions in the module in a single recursive walk.
//
void processInstRec(IRInst* inst)
{
    // Pre-order traversal: handle `inst` itself first, then descend into
    // each of its children in turn.
    //
    processInst(inst);

    for( auto childInst : inst->getChildren() )
        processInstRec(childInst);
}
//
// At each instruction, we will check if it is one of the union-related instructions
// we need to replace, and process it accordingly.
//
void processInst(IRInst* inst)
{
    // Inspect a single instruction and, if it is one of the union-related
    // opcodes, emit replacement code for it in terms of the synthesized
    // `struct` type. The original instruction is never deleted here; it is
    // appended to `instsToRemove` for later removal.
    //
    switch( inst->op )
    {
    default:
        // Any instruction not listed below either doesn't involve union types,
        // or handles them in a hands-off fashion that we don't need to care about.
        //
        // E.g., a `load` of a union type from a constant buffer will turn into
        // a load of the replacement `struct` type once we are done, and nothing
        // needs to be done to the `load` instruction.
        //
        break;

    case kIROp_TaggedUnionType:
        {
            // We clearly need to process the tagged union type itself, but the actual
            // work is handled by other functions. All we need to do here is ensure
            // that the information for this type gets generated, and then we can
            // rely on the main `processModule` function to do the actual replacement later.
            //
            auto type = cast<IRTaggedUnionType>(inst);
            getTaggedUnionInfo(type);
        }
        break;

    case kIROp_ExtractTaggedUnionTag:
        {
            // The case of extracting the tag from a tagged union is relatively
            // simple, because the replacement type will have a dedicated field for it.
            //
            // We start by finding the tagged union value the instruction is operating
            // on, and then looking up the information for its type (which had
            // better be a tagged union type).
            //
            auto taggedUnionVal = inst->getOperand(0);
            auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType());

            // Because the replacement type will have an explicit field for the tag,
            // we can simply emit a single field-extract instruction to read its value
            // out.
            //
            // Note: the insertion point is set *after* the info lookup, because
            // building the info for a not-yet-seen union type repositions the
            // shared builder.
            //
            auto builder = getBuilder();
            builder->setInsertBefore(inst);
            auto replacement = builder->emitFieldExtract(
                inst->getFullType(),
                taggedUnionVal,
                taggedUnionInfo->tagFieldKey);

            // Now we can replace anything that used the original instruction with
            // the new field-extract operation, and add this instruction to the
            // list for later removal.
            //
            inst->replaceUsesWith(replacement);
            instsToRemove.add(inst);
        }
        break;

    case kIROp_ExtractTaggedUnionPayload:
        {
            // The most interesting case is when we are trying to extract a particular
            // payload (one of the case types) from a union. We may need to extract
            // one or more fields from the data stored in the union's replacement
            // type (the bulk/rest fields), and we may also have to convert them
            // to the type expected via bit-casts.

            // We can start things off easily enough by extracting the tagged union
            // value being operated on, as well as the information for its type.
            //
            auto taggedUnionVal = inst->getOperand(0);
            auto taggedUnionInfo = getTaggedUnionInfo(taggedUnionVal->getDataType());

            // All of the extraction code generated below must land immediately
            // before the instruction it replaces. The shared builder may currently
            // be positioned elsewhere (the info lookup above can move it, as can
            // any previously processed instruction), so we position it explicitly,
            // exactly as the tag-extraction case does.
            //
            getBuilder()->setInsertBefore(inst);

            // Next we need to figure out which case is being extracted from the union.
            // The operand for the case tag should be a literal by construction.
            //
            auto caseTagVal = inst->getOperand(1);
            auto caseTagConst = as<IRIntLit>(caseTagVal);
            SLANG_ASSERT(caseTagConst);

            // The case type we are extracting will be the result type of the instruction.
            //
            auto caseType = inst->getDataType();

            // The tag value itself will be the index of the case type in the union
            // type (and its layout).
            //
            auto caseTagIndex = UInt(caseTagConst->getValue());

            // We can use the case tag value to look up the layout for the particular
            // case type we are extracting (this will allow us to resolve byte offsets
            // for fields, etc.).
            //
            auto taggedUnionTypeLayout = taggedUnionInfo->taggedUnionTypeLayout;
            SLANG_ASSERT(caseTagIndex < UInt(taggedUnionTypeLayout->getCaseCount()));
            auto caseTypeLayout = taggedUnionTypeLayout->getCaseTypeLayout(caseTagIndex);

            // At this point we know the type we are trying to extract, as well
            // as its layout. We will defer the actual implementation of extraction
            // to a (recursive) subroutine that can extract a (sub-)field from the
            // union at a given byte offset. Since we are extracting a full case
            // right now, the byte offset will be zero.
            //
            auto payloadVal = extractPayload(
                taggedUnionInfo,
                taggedUnionVal,
                caseType,
                caseTypeLayout,
                0);

            // TODO: There is a significant flaw in the above approach when
            // the case type might be (or contain) an array. If we have a setup
            // like the following:
            //
            //      union SomeUnion { float someCase[100]; ... }
            //      ...
            //      float result = someUnion.someCase[someIndex];
            //
            // The current logic would desugar this into something like:
            //
            //      struct SomeUnion { uint4 bulk[100]; ... }
            //      ...
            //      float[] tmp = { asfloat(someUnion.bulk[0].x), asfloat(someUnion.bulk[1].x), ... }
            //      float result = tmp[someIndex];
            //
            // The result is that we copy an entire 100-element array into local memory
            // just to fetch a single element, when it would be much nicer to just do:
            //
            //      float result = asfloat(someUnion.bulk[someIndex].x);
            //
            // Achieving the latter code requires that rather than blindly translate
            // the `extractTaggedUnionPayload` instruction into a semantically equivalent
            // value (which might lead to a big copy in the end), we should transitively
            // chase down any "access chains" off of `inst` and see what leaf values are
            // actually needed, and generate more tailored extraction logic for just
            // the elements/fields that actually get referenced.
            //
            // The more refined approach can be built on top of many of the same primitives,
            // so for now we will resign ourselves to the simpler but potentially less
            // efficient approach.

            // Now that we've extracted the value for the payload from the fields of
            // the replacement struct, we can use that extracted value to replace
            // this instruction, and schedule the original instruction for removal.
            //
            inst->replaceUsesWith(payloadVal);
            instsToRemove.add(inst);
        }
        break;
    }
}
// Helper for `extractPayload`: fetch a single scalar ("leaf") value of
// `payloadSize` bytes located `payloadOffset` bytes into the union's storage,
// and return it as a value of type `payloadType`.
//
// The size is passed explicitly (instead of being read from a type layout)
// so that callers without a layout for the leaf -- notably the elements of
// a vector, for which no per-element layout exists -- can still use it.
//
// Only 4-byte leaves are currently supported; anything else is reported
// through `SLANG_UNIMPLEMENTED_X`.
//
IRInst* extractLeafPayload(
    TaggedUnionInfo* taggedUnionInfo,
    IRInst* taggedUnionVal,
    IRType* payloadType,
    UInt64 payloadSize,
    UInt64 payloadOffset)
{
    auto builder = getBuilder();

    if( payloadSize != 4 )
    {
        // TODO: We should handle the case of 64-bit fields by fetching
        // two `uint` values to form a `uint2`, and then using an
        // appropriate bit-cast to get from `uint2` to, e.g., `double`.
        //
        // The case of 16-bit and smaller fields is more troublesome, but
        // in the worst case we can load a `uint` and then use bitwise
        // ops to extract what we need before bitcasting.
        //
        // The right long-term solution is for downstream languages to have
        // better support for raw memory addressing.
        SLANG_UNIMPLEMENTED_X("leaf union field with size other than 4 bytes");
    }

    // We know that we want to fetch a value of size `payloadSize`, and
    // we have a known base value and an initial offset into it.
    //
    IRInst* baseVal = taggedUnionVal;
    UInt64 offset = payloadOffset;

    // We are going to refine our `baseVal` and `offset` as we go, by
    // trying to narrow down the data we will access in the `struct`
    // type that will provide storage for the union.
    //
    // The first thing we want to check is if the value sits in the
    // "bulk" part of the storage, or the "rest."
    //
    UInt64 bulkSize = taggedUnionInfo->bulkSize;
    if( offset < bulkSize )
    {
        // If the value starts in the bulk area, then the whole
        // thing had better fit in the bulk area. The 16-byte
        // granularity rules for constant buffers should ensure
        // this property for us on current targets.
        //
        SLANG_ASSERT(offset + payloadSize <= bulkSize);

        // Since we know we'll be accessing the bulk storage,
        // we will extract it here. The extracted field will
        // be our new base value, but the `offset` doesn't need
        // to be updated since the bulk field sits at offset 0.
        //
        baseVal = builder->emitFieldExtract(
            taggedUnionInfo->bulkFieldType,
            baseVal,
            taggedUnionInfo->bulkFieldKey);

        // The bulk storage could be an array, if there are 32
        // or more bytes of bulk storage.
        //
        if( auto baseArrayType = as<IRArrayType>(baseVal->getDataType()) )
        {
            // If an array was allocated for bulk storage then
            // our leaf value resides entirely within a single
            // element (due to constant buffer layout rules),
            // and so we will fetch the appropriate element here.
            //
            // We will change our `baseVal` to the extracted element,
            // and then also adjust our `offset` to be relative
            // to that element.
            //
            size_t bulkElementSize = 16;
            auto index = offset / bulkElementSize;
            baseVal = builder->emitElementExtract(
                baseArrayType->getElementType(),
                baseVal,
                builder->getIntValue(builder->getIntType(), index));
            offset -= index*bulkElementSize;
        }
    }
    else
    {
        // If the offset of the field we want is past the end of
        // the bulk field then it must sit inside of the rest field,
        // and we'll extract it here. This establishes a new
        // base value, and we adjust the `offset` to be relative
        // to the rest field (which starts at an offset equal to `bulkSize`).
        //
        baseVal = builder->emitFieldExtract(
            taggedUnionInfo->restFieldType,
            baseVal,
            taggedUnionInfo->restFieldKey);
        offset -= bulkSize;
    }

    // We've now extracted a field that could be either a scalar or
    // a vector, and we have an offset into it. In the case where
    // the base value is a vector, we will extract out the appropriate
    // element.
    //
    if( auto baseVecType = as<IRVectorType>(baseVal->getDataType()) )
    {
        size_t vecElementSize = 4;
        auto index = offset / vecElementSize;
        baseVal = builder->emitElementExtract(
            baseVecType->getElementType(),
            baseVal,
            builder->getIntValue(builder->getIntType(), index));
        offset -= index*vecElementSize;
    }

    // At this point, our `baseVal` should be a single `uint`, and
    // it should provide the storage for the exact thing we wanted
    // to access (under the assumption that we always fetch 4 bytes
    // on 4-byte alignment).
    //
    IRInst* payloadVal = baseVal;
    SLANG_ASSERT(offset == 0);

    // TODO: we could imagine adding logic here to handle types less
    // than 4 bytes in size by shifting and masking the value we
    // just loaded.

    // The payload field we were trying to extract might have a type
    // other than `uint`, and to handle that case we need to employ
    // a bit-cast to get to the desired type.
    //
    if( payloadVal->getDataType() != payloadType )
    {
        payloadVal = builder->emitBitCast(
            payloadType,
            payloadVal);
    }
    return payloadVal;
}

// The `extractPayload` operation is the most important bit of translation we
// need to do to make unions work. It emits code (at the builder's current
// insertion point) that reconstructs a value of `payloadType` from the
// storage fields of the union's replacement struct, and returns that value.
// We have as input the following:
//
IRInst* extractPayload(
    // - Information about a tagged union type and its layout.
    TaggedUnionInfo* taggedUnionInfo,
    // - A single value of that tagged union type.
    IRInst* taggedUnionVal,
    // - The type of some "payload" field we want to extract from the union.
    IRType* payloadType,
    // - The memory layout of that payload type (must not be null).
    IRTypeLayout* payloadTypeLayout,
    // - The byte offset at which we want to fetch the payload.
    UInt64 payloadOffset)
{
    // Every case below needs layout information for the payload, so a
    // missing layout is a caller error rather than something to recover from.
    //
    SLANG_ASSERT(payloadTypeLayout);

    // We are going to be building some IR code no matter what.
    //
    auto builder = getBuilder();

    // The basic approach here will be to look at the type we
    // are trying to extract from the union, and whenever possible
    // recursively walk its structure so that we can express things
    // in terms of extraction of smaller/simpler types.
    //
    if( auto irStructType = as<IRStructType>(payloadType) )
    {
        // A structure type is a nice recursive case: we simply
        // want to extract each of its fields recursively, and
        // then construct a fresh value of the `struct` type.

        // In all of the cases of this function we expect/require
        // there to be complete type layout information for the
        // types involved.
        //
        auto structTypeLayout = as<IRStructTypeLayout>(payloadTypeLayout);
        SLANG_ASSERT(structTypeLayout);

        // We are going to emit code to extract each of the fields
        // and collect them to use as operands to a `makeStruct`.
        //
        List<IRInst*> fieldVals;

        // We need to walk over the fields in the order the IR expects them
        UInt fieldCounter = 0;
        for( auto irField : irStructType->getFields() )
        {
            IRType* fieldType = irField->getFieldType();

            // TODO: We need to confirm/enforce that the fields of the
            // IR struct and the fields of the layout still align.
            //
            UInt fieldIndex = fieldCounter++;
            auto fieldLayout = structTypeLayout->getFieldLayout(fieldIndex);
            auto fieldTypeLayout = fieldLayout->getTypeLayout();

            // The offset of the field can be computed from the base
            // offset passed in, plus the reflection data for the field.
            //
            UInt64 fieldOffset = payloadOffset;
            if(auto resInfo = fieldLayout->findOffsetAttr(LayoutResourceKind::Uniform))
                fieldOffset += resInfo->getOffset();

            // We make a recursive call to extract each field, expecting
            // that this will bottom out eventually.
            //
            IRInst* fieldVal = extractPayload(
                taggedUnionInfo,
                taggedUnionVal,
                fieldType,
                fieldTypeLayout,
                fieldOffset);
            fieldVals.add(fieldVal);
        }

        // The final value is then just a new struct constructed from
        // the extracted field values.
        //
        auto payloadVal = builder->emitMakeStruct(irStructType, fieldVals);
        return payloadVal;
    }
    else if( auto vecType = as<IRVectorType>(payloadType) )
    {
        auto elementType = vecType->getElementType();

        // We expect that by the time we are desugaring union types
        // all vector types have literal constant values for their
        // element count.
        //
        auto elementCountVal = vecType->getElementCount();
        auto elementCountConst = as<IRIntLit>(elementCountVal);
        SLANG_ASSERT(elementCountConst);
        UInt elementCount = UInt(elementCountConst->getValue());

        // HACK: There is currently no `VectorTypeLayout` and thus
        // no way to query the layout of the elements of a vector
        // type. Until that gets added we derive the element size from
        // the size of the whole vector, and fetch each element through
        // `extractLeafPayload`, which needs only a size and not a layout.
        // (Recursing through `extractPayload` here would require passing
        // a null element layout, which the leaf case would dereference.)
        //
        UInt64 elementSize = 0;
        if( elementCount != 0 )
        {
            if(auto resInfo = payloadTypeLayout->findSizeAttr(LayoutResourceKind::Uniform))
                elementSize = resInfo->getSize().getFiniteValue() / elementCount;
        }

        // Similar to the `struct` case above, we will extract a
        // value for each element of the vector, and then use
        // `makeVector` to construct the result value.
        //
        List<IRInst*> elementVals;
        for(UInt ii = 0; ii < elementCount; ++ii)
        {
            auto elementVal = extractLeafPayload(
                taggedUnionInfo,
                taggedUnionVal,
                elementType,
                elementSize,
                payloadOffset + ii*elementSize);
            elementVals.add(elementVal);
        }

        return builder->emitMakeVector(vecType, elementVals);
    }
    else if( auto matType = as<IRMatrixType>(payloadType) )
    {
        SLANG_UNIMPLEMENTED_X("matrix in union type");
    }
    else if( auto arrayType = as<IRArrayType>(payloadType) )
    {
        SLANG_UNIMPLEMENTED_X("array in union type");
    }
    else
    {
        // If none of the above cases match, then we assume that
        // we have an individual scalar field that we need to fetch.
        //
        UInt64 payloadSize = 0;
        if( auto resInfo = payloadTypeLayout->findSizeAttr(LayoutResourceKind::Uniform) )
        {
            // TODO: somebody before this point should generate an error if
            // we have a `union` type that contains a potentially unbounded
            // amount of data.
            //
            payloadSize = resInfo->getSize().getFiniteValue();
        }

        return extractLeafPayload(
            taggedUnionInfo,
            taggedUnionVal,
            payloadType,
            payloadSize,
            payloadOffset);
    }
}
// All of the logic so far has assumed we can just call `getTaggedUnionInfo`
// and have easy access to all the required information and the
// synthesized replacement type.
//
TaggedUnionInfo* getTaggedUnionInfo(IRType* type)
{
    // Lazily build and memoize the info record for `type`.
    //
    // On a cache miss the freshly built record is registered in two places:
    // the lookup map (so later queries for the same type are hits), and the
    // list of all unions seen (so `processModule` can replace them later).
    //
    TaggedUnionInfo* cachedInfo = nullptr;
    if( !mapIRTypeToTaggedUnionInfo.TryGetValue(type, cachedInfo) )
    {
        RefPtr<TaggedUnionInfo> freshInfo = createTaggedUnionInfo(type);
        cachedInfo = freshInfo.Ptr();

        mapIRTypeToTaggedUnionInfo.Add(type, cachedInfo);
        taggedUnionInfos.add(freshInfo);
    }
    return cachedInfo;
}
// The actual logic for creating a `TaggedUnionInfo` is relatively
// straightforward once we've decided what information we need.
//
RefPtr<TaggedUnionInfo> createTaggedUnionInfo(IRType* type)
{
    // `type` is required to be a concrete IR tagged union. Even if generic
    // tagged unions are exposed to users some day, this pass is expected to
    // run late enough that any such generic has been fully specialized, so
    // no `specialize` instruction should show up here.
    //
    auto unionType = as<IRTaggedUnionType>(type);
    SLANG_ASSERT(unionType);

    RefPtr<TaggedUnionInfo> result = new TaggedUnionInfo();
    result->taggedUnionType = unionType;

    // The replacement is created right next to the original, in the same
    // parent. Note that this repositions the shared builder.
    //
    auto irBuilder = getBuilder();
    irBuilder->setInsertBefore(type);

    // The replacement itself is an ordinary `struct` that will receive the
    // storage fields for the payload, followed by a tag field.
    //
    auto replacementStruct = irBuilder->createStructType();
    result->replacementInst = replacementStruct;

    // Earlier code generation steps are expected to have attached a
    // tagged-union layout to every tagged union type in the module.
    //
    auto layoutDecor = type->findDecoration<IRLayoutDecoration>();
    SLANG_ASSERT(layoutDecor);
    auto rawLayout = layoutDecor->getLayout();
    SLANG_ASSERT(rawLayout);
    auto unionLayout = as<IRTaggedUnionTypeLayout>(rawLayout);
    SLANG_ASSERT(unionLayout);
    result->taggedUnionTypeLayout = unionLayout;

    // The payload (everything except the tag) is taken to occupy all the
    // bytes that precede the tag.
    //
    // TODO: this might be inaccurate if the payload size isn't a multiple
    // of the tag's alignment. We should deal with that when/if we support
    // types smaller than 4 bytes in unions.
    //
    auto payloadBytes = unionLayout->getTagOffset().getFiniteValue();

    // `int` is used for element counts, `uint` for the storage itself.
    //
    auto intType = irBuilder->getIntType();
    auto uintType = irBuilder->getBasicType(BaseType::UInt);

    // Small utility: append one field of the given type to the replacement
    // struct, and hand back the key that identifies it.
    //
    auto appendField = [&](IRType* fieldType)
    {
        auto key = irBuilder->createStructKey();
        irBuilder->createStructField(replacementStruct, key, fieldType);
        return key;
    };

    // The encoding depends only on the number of payload bytes, never on
    // what the individual cases contain.
    //
    // First come as many whole 16-byte chunks as fit: a single `uint4`,
    // or an array of `uint4` when there are two or more.
    //
    size_t bulkChunkBytes = 16; // Note: assuming `sizeof(uint4) == 16` on all targets
    auto bulkChunkCount = payloadBytes / bulkChunkBytes;
    auto bulkBytes = bulkChunkCount * bulkChunkBytes;
    if( bulkChunkCount )
    {
        IRType* bulkType = irBuilder->getVectorType(
            uintType,
            irBuilder->getIntValue(intType, 4));
        if( bulkChunkCount > 1 )
        {
            bulkType = irBuilder->getArrayType(
                bulkType,
                irBuilder->getIntValue(intType, bulkChunkCount));
        }

        result->bulkFieldKey = appendField(bulkType);
        result->bulkFieldType = bulkType;
    }
    result->bulkSize = bulkBytes;

    // Whatever is left over after the bulk chunks goes into one scalar
    // `uint`, or a vector of `uint` when more than one word remains.
    //
    auto restBytes = payloadBytes - bulkBytes;
    if( restBytes )
    {
        size_t restWordBytes = 4; // assuming `sizeof(uint) == 4` on all targets
        auto restWordCount = restBytes / restWordBytes;
        auto restFieldBytes = restWordBytes * restWordCount;

        // Note: all our current targets have minimum 4-byte storage granularity
        SLANG_ASSERT(restFieldBytes == restBytes);

        IRType* restType = uintType;
        if( restWordCount > 1 )
        {
            restType = irBuilder->getVectorType(
                restType,
                irBuilder->getIntValue(intType, restWordCount));
        }

        result->restFieldKey = appendField(restType);
        result->restFieldType = restType;
        result->restSize = restFieldBytes;
    }

    // The tag always comes last, and is always a single `uint`.
    //
    result->tagFieldKey = appendField(uintType);

    return result;
}
};
// Entry point of the pass: replaces the union types in `module` (and the
// instructions that operate on them) with plain `struct`-based equivalents.
void desugarUnionTypes(
    IRModule* module)
{
    // All of the work lives in the context object; it only needs to be
    // told which module to operate on.
    DesugarUnionTypesContext desugarContext;
    desugarContext.module = module;
    desugarContext.processModule();
}
} // namespace Slang