https://github.com/shader-slang/slang
Raw File
Tip revision: 01efe34dbef2be952298075abd8d36cc67ac9f4e authored by Yong He on 04 March 2024, 21:14:21 UTC
Add `IGlobalSession::getSessionDescDigest`. (#3669)
Tip revision: 01efe34
slang-ir-sccp.cpp
// slang-ir-sccp.cpp
#include "slang-ir-sccp.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"

namespace Slang {


// This file implements the Spare Conditional Constant Propagation (SCCP) optimization.
//
// We will apply the optimization over individual functions, so we will start with
// a context struct for the state that we will share across functions:
//
struct SharedSCCPContext
{
    IRModule*       module;
    DiagnosticSink* sink;
};
//
// Next we have a context struct that will be applied for each function (or other
// code-bearing value) that we optimize:
//
struct SCCPContext
{
    SharedSCCPContext*      shared; // shared state across functions
    IRGlobalValueWithCode*  code;   // the function/code we are optimizing

    // The SCCP algorithm applies abstract interpretation to the code of the
    // function using a "lattice" of values. We can think of a node on the
    // lattice as representing a set of values that a given instruction
    // might take on.
    //
    struct LatticeVal
    {
        // We will use three "flavors" of values on our lattice.
        //
        enum class Flavor
        {
            // The `None` flavor represent an empty set of values, meaning
            // that we've never seen any indication that the instruction
            // produces a (well-defined) value. This could indicate an
            // instruction that does not appear to execute, but it could
            // also indicate an instruction that we know invokes undefined
            // behavior, so we can freely pick a value for it on a whim.
            None,

            // The `Constant` flavor represents an instuction that we
            // have only ever seen produce a single, fixed value. It's
            // `value` field will hold that constant value.
            Constant,

            // The `Any` flavor represents an instruction that might produce
            // different values at runtime, so we go ahead and approximate
            // this as it potentially yielding any value whatsoever. A
            // more precise analysis could use sets or intervals of values,
            // but for SCCP anything that could take on more than 1 value
            // at runtime is assumed to be able to take on *any* value.
            Any,
        };

        // The flavor of this value (`None`, `Constant`, or `Any`)
        Flavor  flavor;

        // If this is a `Constant` lattice value, then this field
        // points to the IR instruction that defines the actual constant value.
        // For all other flavors it should be null.
        IRInst* value = nullptr;

        // For convenience, we define `static` factory functions to
        // produce values of each of the flavors.

        static LatticeVal getNone()
        {
            LatticeVal result;
            result.flavor = Flavor::None;
            return result;
        }

        static LatticeVal getAny()
        {
            LatticeVal result;
            result.flavor = Flavor::Any;
            return result;
        }

        static LatticeVal getConstant(IRInst* value)
        {
            LatticeVal result;
            result.flavor = Flavor::Constant;
            result.value = value;
            return result;
        }

        // We also need to be able to test if two lattice
        // values are equal, so that we can avoid updating
        // downstream dependencies if our knowledge about
        // an instruction hasn't actually changed.
        //
        bool operator==(LatticeVal const& that)
        {
            return this->flavor == that.flavor
                && this->value == that.value;
        }

        bool operator!=(LatticeVal const& that)
        {
            return !( *this == that );
        }
    };

    static bool isEvaluableOpCode(IROp op)
    {
        switch (op)
        {
        case kIROp_IntLit:
        case kIROp_BoolLit:
        case kIROp_FloatLit:
        case kIROp_StringLit:
        case kIROp_Add:
        case kIROp_Sub:
        case kIROp_Mul:
        case kIROp_Div:
        case kIROp_Neg:
        case kIROp_Not:
        case kIROp_Eql:
        case kIROp_Neq:
        case kIROp_Leq:
        case kIROp_Geq:
        case kIROp_Less:
        case kIROp_Greater:
        case kIROp_Lsh:
        case kIROp_Rsh:
        case kIROp_BitAnd:
        case kIROp_BitOr:
        case kIROp_BitXor:
        case kIROp_BitNot:
        case kIROp_BitCast:
        case kIROp_CastIntToFloat:
        case kIROp_CastFloatToInt:
        case kIROp_IntCast:
        case kIROp_FloatCast:
        case kIROp_Select:
            return true;
        default:
            return false;
        }
    }

    // If we imagine a variable (actually an SSA phi node...) that
    // might be assigned lattice value A at one point in the code,
    // and lattice value B at another point, we need a way to
    // combine these to form our knowledge of the possible value(s)
    // for the variable.
    //
    // In terms of computation on a lattice, we want the "meet"
    // operation, which computes the lower bound on what we know.
    // If we interpret our lattice values as sets, then we are
    // trying to compute the union.
    //
    LatticeVal meet(LatticeVal const& left, LatticeVal const& right)
    {
        // If either value is `None` (the empty set), then the union
        // will be the other value.
        //
        if(left.flavor == LatticeVal::Flavor::None) return right;
        if(right.flavor == LatticeVal::Flavor::None) return left;

        // If either value is `Any` (the universal set), then
        // the union is also the universal set.
        //
        if(left.flavor == LatticeVal::Flavor::Any) return LatticeVal::getAny();
        if(right.flavor == LatticeVal::Flavor::Any) return LatticeVal::getAny();

        // At this point we've ruled out the case where either value
        // is `None` *or* `Any`, so we can assume both values are
        // `Constant`s.
        SLANG_ASSERT(left.flavor == LatticeVal::Flavor::Constant);
        //
        SLANG_ASSERT(right.flavor == LatticeVal::Flavor::Constant);

        // If the two lattice values represent the *same* constant value
        // (they are the same singleton set) then the union is that
        // singleton set as well.
        //
        // TODO: This comparison assumes that constants with
        // the same value with be represented with the
        // same instruction, which is not *always*
        // guaranteed in the IR today.
        //
        if(left.value == right.value)
            return left;

        // Otherwise, we have two distinct singleton sets, and their
        // union should be a set with two elements. We can't represent
        // that on the lattice for SCCP, so the proper lower bound
        // is the universal set (`Any`)
        //
        return LatticeVal::getAny();
    }

    // During the execution of the SCCP algorithm, we will track our best
    // "estimate" so far of the set of values each instruction could take
    // on. This amounts to a mapping from IR instructions to lattice values,
    // where any instruction not present in the map is assumed to default
    // to the `None` case (the empty set)
    //
    Dictionary<IRInst*, LatticeVal> mapInstToLatticeVal;

    // Updating the lattice value for an instruction is easy, but we'll
    // use a simple function to make our intention clear.
    //
    void setLatticeVal(IRInst* inst, LatticeVal const& val)
    {
        mapInstToLatticeVal[inst] = val;
    }

    // Querying the lattice value for an instruction isn't *just* a matter
    // of looking it up in the dictionary, because we need to account for
    // cases of lattice values that might come from outside the current
    // function.
    //
    LatticeVal getLatticeVal(IRInst* inst)
    {
        // Instructions that represent constant values should always
        // have a lattice value that reflects this.
        //
        switch( inst->getOp() )
        {
        case kIROp_IntLit:
        case kIROp_FloatLit:
        case kIROp_StringLit:
        case kIROp_BoolLit:
            return LatticeVal::getConstant(inst);
            break;

        // TODO: We might want to start having support for constant
        // values of aggregate types (e.g., a `makeArray` or `makeStruct`
        // where all the operands are constant is itself a constant).

        default:
            break;
        }

        // Look up in the dictionary and just return the value we get from it.
        LatticeVal latticeVal;
        if(mapInstToLatticeVal.tryGetValue(inst, latticeVal))
            return latticeVal;

        // If we can't find the value from dictionary, we want to return None if this is a value
        // in the same function as the one we are working with right now. If it is defined
        // elsewhere, we return Any.
        auto parentBlock = as<IRBlock>(inst->getParent());
        bool isProcessingGlobalScope = (code == nullptr);
        if (!parentBlock && isProcessingGlobalScope)
        {
            // We are folding constant in the global scope, continue registering the inst as Any.
        }
        else
        {
            // If we are processing a function and asked for the lattice value of an instruction
            // not contained in the current function, we will treat it as having potentially any
            // value, rather than the default of none.
            //
            if(!parentBlock || parentBlock->getParent() != code) return LatticeVal::getAny();
        }

        return LatticeVal::getNone();
    }

    // Along the way we might need to create new IR instructions
    // to represnet new constant values we find, or new control
    // flow instructiosn when we start simplifying things.
    //
    IRBuilder builderStorage;
    IRBuilder* getBuilder() { return &builderStorage; }

    // LatticeVal constant evaluation methods.
#define SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v) \
    switch (v.flavor)                       \
    {                                       \
    case LatticeVal::Flavor::None:          \
        return LatticeVal::getNone();       \
    case LatticeVal::Flavor::Any:           \
        return LatticeVal::getAny();        \
    default:                                \
        break;                              \
    }

    LatticeVal evalCast(IRType* type, LatticeVal v0)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto irConstant = as<IRConstant>(v0.value);

        IRInst* resultVal = nullptr;
        if (type->getOp() == irConstant->getOp())
            return LatticeVal::getConstant(irConstant);

        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_IntPtrType:
        case kIROp_UIntPtrType:
            switch (irConstant->getOp())
            {
            case kIROp_FloatLit:
                resultVal =
                    getBuilder()->getIntValue(type, (IRIntegerValue)irConstant->value.floatVal);
                break;
            case kIROp_IntLit:
            case kIROp_BoolLit:
                {
                    IRIntegerValue intVal = irConstant->value.intVal;
                    resultVal = getBuilder()->getIntValue(type, (IRIntegerValue)intVal);
                }
                break;
            default:
                return LatticeVal::getAny();
            }
            break;
        case kIROp_FloatType:
        case kIROp_DoubleType:
        case kIROp_HalfType:
            switch (irConstant->getOp())
            {
            case kIROp_FloatLit:
                resultVal = getBuilder()->getFloatValue(
                    type, (IRFloatingPointValue)irConstant->value.floatVal);
                break;
            case kIROp_IntLit:
            case kIROp_BoolLit:
                resultVal = getBuilder()->getFloatValue(
                    type, (IRFloatingPointValue)irConstant->value.intVal);
                break;
            default:
                return LatticeVal::getAny();
            }
            break;
        case kIROp_BoolType:
            switch (irConstant->getOp())
            {
            case kIROp_FloatLit:
                resultVal = getBuilder()->getBoolValue(irConstant->value.floatVal != 0);
                break;
            case kIROp_IntLit:
            case kIROp_BoolLit:
                {
                    resultVal = getBuilder()->getBoolValue(irConstant->value.intVal != 0);
                }
                break;
            default:
                return LatticeVal::getAny();
            }
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    LatticeVal evalDefaultConstruct(IRType* type)
    {
        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_IntPtrType:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_UIntPtrType:
            resultVal = getBuilder()->getIntValue(type, (IRIntegerValue)0);
            break;

        case kIROp_FloatType:
        case kIROp_DoubleType:
        case kIROp_HalfType:
            resultVal = getBuilder()->getFloatValue(type, (IRFloatingPointValue)0.0);
            break;

        case kIROp_BoolType:
            resultVal = getBuilder()->getBoolValue(false);
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    template<typename TIntFunc, typename TFloatFunc>
    LatticeVal evalBinaryImpl(
        IRType* type,
        LatticeVal v0,
        LatticeVal v1,
        const TIntFunc& intFunc,
        const TFloatFunc& floatFunc)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
        auto c1 = as<IRConstant>(v1.value);
        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_IntPtrType:
        case kIROp_UIntPtrType:
        case kIROp_BoolType:
            resultVal = getBuilder()->getIntValue(type, intFunc(c0->value.intVal, c1->value.intVal));
            break;
        case kIROp_FloatType:
        case kIROp_DoubleType:
        case kIROp_HalfType:
            resultVal = getBuilder()->getFloatValue(type, floatFunc(c0->value.floatVal, c1->value.floatVal));
            break;
        default:
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    template <typename TIntFunc>
    LatticeVal evalBinaryIntImpl(
        IRType* type,
        LatticeVal v0,
        LatticeVal v1,
        const TIntFunc& intFunc)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
        auto c1 = as<IRConstant>(v1.value);
        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_IntPtrType:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_UIntPtrType:
        case kIROp_BoolType:
            resultVal =
                getBuilder()->getIntValue(type, intFunc(c0->value.intVal, c1->value.intVal));
            break;
        default:
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    template <typename TIntFunc>
    LatticeVal evalUnaryIntImpl(
        IRType* type, LatticeVal v0, const TIntFunc& intFunc)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);
        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_IntPtrType:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_UIntPtrType:
        case kIROp_BoolType:
            resultVal =
                getBuilder()->getIntValue(type, intFunc(c0->value.intVal));
            break;
        default:
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    template <typename TIntFunc, typename TFloatFunc>
    LatticeVal evalComparisonImpl(
        IRType* type,
        LatticeVal v0,
        LatticeVal v1,
        const TIntFunc& intFunc,
        const TFloatFunc& floatFunc)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v1)
        auto c1 = as<IRConstant>(v1.value);
        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_IntPtrType:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_UIntPtrType:
        case kIROp_BoolType:
            resultVal =
                getBuilder()->getBoolValue(intFunc(c0->value.intVal, c1->value.intVal));
            break;
        case kIROp_FloatType:
        case kIROp_DoubleType:
        case kIROp_HalfType:
            resultVal =
                getBuilder()->getBoolValue(floatFunc(c0->value.floatVal, c1->value.floatVal));
            break;
        default:
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    LatticeVal evalAdd(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 + c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 + c1; });
    }
    LatticeVal evalSub(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 - c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 - c1; });
    }
    LatticeVal evalMul(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 * c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 * c1; });
    }
    LatticeVal evalDiv(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) {  return c0 / c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 / c1; });
    }
    LatticeVal evalEql(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalComparisonImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 == c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 == c1; });
    }
    LatticeVal evalNeq(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalComparisonImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 != c1; });
    }
    LatticeVal evalGeq(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalComparisonImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 >= c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 >= c1; });
    }
    LatticeVal evalLeq(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalComparisonImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 <= c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 <= c1; });
    }
    LatticeVal evalGreater(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalComparisonImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 > c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 > c1; });
    }
    LatticeVal evalLess(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalComparisonImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 < c1; },
            [](IRFloatingPointValue c0, IRFloatingPointValue c1) { return c0 < c1; });
    }
    LatticeVal evalAnd(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type,
            v0,
            v1,
            [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != 0 && c1 != 0; });
    }
    LatticeVal evalOr(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 != 0 || c1 != 0; });
    }
    LatticeVal evalNot(IRType* type, LatticeVal v0)
    {
        return evalUnaryIntImpl(type, v0, [](IRIntegerValue c0) { return c0 == 0; });
    }
    LatticeVal evalBitAnd(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 & c1; });
    }
    LatticeVal evalBitOr(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 | c1; });
    }
    LatticeVal evalBitNot(IRType* type, LatticeVal v0)
    {
        return evalUnaryIntImpl(type, v0, [](IRIntegerValue c0) { return ~c0; });
    }
    LatticeVal evalBitXor(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 ^ c1; });
    }
    LatticeVal evalLsh(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 << c1; });
    }
    LatticeVal evalRsh(IRType* type, LatticeVal v0, LatticeVal v1)
    {
        return evalBinaryIntImpl(
            type, v0, v1, [](IRIntegerValue c0, IRIntegerValue c1) { return c0 >> c1; });
    }
    LatticeVal evalNeg(IRType* type, LatticeVal v0)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);
        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int8Type:
        case kIROp_Int16Type:
        case kIROp_IntType:
        case kIROp_Int64Type:
        case kIROp_IntPtrType:
        case kIROp_UInt8Type:
        case kIROp_UInt16Type:
        case kIROp_UIntType:
        case kIROp_UInt64Type:
        case kIROp_UIntPtrType:
            resultVal = getBuilder()->getIntValue(type, -c0->value.intVal);
            break;
        case kIROp_FloatType:
        case kIROp_DoubleType:
        case kIROp_HalfType:
            resultVal = getBuilder()->getFloatValue(type, -c0->value.floatVal);
            break;
        default:
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    LatticeVal evalBitCast(IRType* type, LatticeVal v0)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);

        uint64_t sourceValueBits = 0;
        switch (c0->getDataType()->getOp())
        {
        case kIROp_FloatType:
        {
            float fval = (float)c0->value.floatVal;
            memcpy(&sourceValueBits, &fval, sizeof(fval));
            break;
        }
        case kIROp_DoubleType:
        {
            double dval = c0->value.floatVal;
            memcpy(&sourceValueBits, &dval, sizeof(dval));
            break;
        }
        case kIROp_BoolType:
        {
            sourceValueBits = c0->value.intVal;
            break;
        }
        default:
            if (isIntegralType(c0->getDataType()))
            {
                sourceValueBits = c0->value.intVal;
            }
            else
            {
                return LatticeVal::getAny();
            }
            break;
        }

        IRInst* resultVal = nullptr;
        switch (type->getOp())
        {
        case kIROp_Int64Type:
        case kIROp_UInt64Type:
#if SLANG_PTR_IS_64
        case kIROp_IntPtrType:
        case kIROp_UIntPtrType:
#endif
            resultVal = getBuilder()->getIntValue(type, sourceValueBits);
            break;
        case kIROp_IntType:
        case kIROp_UIntType:
#if SLANG_PTR_IS_32
        case kIROp_IntPtrType:
        case kIROp_UIntPtrType:
#endif
            resultVal = getBuilder()->getIntValue(type, (uint32_t)sourceValueBits);
            break;
        case kIROp_FloatType:
            {
                uint32_t val = (uint32_t)sourceValueBits;
                float floatVal = IntAsFloat((int)val);
                resultVal = getBuilder()->getFloatValue(type, floatVal);
            }
            break;
        case kIROp_DoubleType:
            resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(sourceValueBits));
            break;
        default:
            break;
        }
        if (!resultVal)
            return LatticeVal::getAny();
        return LatticeVal::getConstant(resultVal);
    }

    LatticeVal evalSelect(LatticeVal v0, LatticeVal v1, LatticeVal v2)
    {
        SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
        auto c0 = as<IRConstant>(v0.value);
        return c0->value.intVal != 0 ? v1 : v2;
    }

    // In order to perform constant folding, we need to be able to
    // interpret an instruction over the lattice values.
    //
    LatticeVal interpretOverLattice(IRInst* inst)
    {
        SLANG_UNUSED(inst);

        // Certain instruction always produce constants, and we
        // want to special-case them here.
        switch( inst->getOp() )
        {
        case kIROp_IntLit:
        case kIROp_FloatLit:
        case kIROp_StringLit:
        case kIROp_BoolLit:
            return LatticeVal::getConstant(inst);

        // We might also want to special-case certain
        // instructions where we shouldn't bother trying to
        // constant-fold them and should just default to the
        // `Any` value right away.
        case kIROp_Call:
        case kIROp_ByteAddressBufferLoad:
        case kIROp_ByteAddressBufferStore:
        case kIROp_Alloca:
        case kIROp_Store:
        case kIROp_Load:
            return LatticeVal::getAny();
        default:
            break;
        }

        // TODO: We should now look up the lattice values for
        // the operands of the instruction.
        //
        // If all of the operands have `Constant` lattice values,
        // then we can potential execute the operation directly
        // on those constant values, create a fresh `IRConstant`,
        // and return a `Constant` lattice value for it. This
        // would allow us to achieve true constant folding here.
        //
        // Textbook discussions of SCCP often point out that it
        // is also possible to perform certain algebraic simplifications
        // here, such as evaluating a multiply by a `Constant` zero
        // to zero.
        //
        // As a default, if any operand has the `Any` value
        // then the result of the operation should be treated as
        // `Any`. There are exceptions to this, however, with the
        // multiply-by-zero example being an important example.
        // If we had previously decided that (Any * None) -> Any
        // but then we refine our estimates and have (Any * Constant(0)) -> Constant(0)
        // then we have violated the monotonicity rules for how
        // our values move through the lattice, and we may break
        // the convergence guarantees of the analysis.
        //
        // When we have a mix of `None` and `Constant` operands,
        // then the `None` values imply that our operation is using
        // uninitialized data or the results of undefined behavior.
        // We could try to propagate the `None` through, and allow
        // the compiler to speculatively assume that the operation
        // produces whatever value we find convenient. Alternatively,
        // we can be less aggressive and treat an operation with
        // `None` inputs as producing `Any` to make sure we don't
        // optimize the code based on non-obvious assumptions.
        //
        // For now we implement only basic folding operations for
        // scalar values.
        if (!as<IRBasicType>(inst->getDataType()))
            return LatticeVal::getAny();

        switch (inst->getOp())
        {
        case kIROp_IntCast:
        case kIROp_FloatCast:
        case kIROp_CastIntToFloat:
        case kIROp_CastFloatToInt:
            switch (inst->getOperandCount())
            {
            case 1:
                return evalCast(inst->getDataType(), getLatticeVal(inst->getOperand(0)));

            default:
                return LatticeVal::getAny();
            }
        case kIROp_DefaultConstruct:
            return evalDefaultConstruct(inst->getDataType());
        case kIROp_Add:
            return evalAdd(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Sub:
            return evalSub(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Mul:
            return evalMul(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Div:
        {
            // Detect divide by zero error.
            auto divisor = getLatticeVal(inst->getOperand(1));
            if (divisor.flavor == LatticeVal::Flavor::Constant)
            {
                if (isIntegralType(divisor.value->getDataType()))
                {
                    auto c = as<IRConstant>(divisor.value);
                    if (c->value.intVal == 0)
                    {
                        if (shared->sink)
                            shared->sink->diagnose(inst->sourceLoc, Diagnostics::divideByZero);
                        return LatticeVal::getAny();
                    }
                }
            }
            return evalDiv(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                divisor);
        }
        case kIROp_Eql:
            return evalEql(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Neq:
            return evalNeq(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Greater:
            return evalGreater(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Less:
            return evalLess(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Leq:
            return evalLeq(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Geq:
            return evalGeq(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_And:
            return evalAnd(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Or:
            return evalOr(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Not:
            return evalNot(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
        case kIROp_BitAnd:
            return evalBitAnd(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_BitOr:
            return evalBitOr(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_BitNot:
            return evalBitNot(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
        case kIROp_BitXor:
            return evalBitXor(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_BitCast:
            return evalBitCast(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
        case kIROp_Neg:
            return evalNeg(inst->getDataType(), getLatticeVal(inst->getOperand(0)));
        case kIROp_Lsh:
            return evalLsh(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Rsh:
            return evalRsh(
                inst->getDataType(),
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)));
        case kIROp_Select:
            return evalSelect(
                getLatticeVal(inst->getOperand(0)),
                getLatticeVal(inst->getOperand(1)),
                getLatticeVal(inst->getOperand(2)));
        default:
            break;
        }

        // A safe default is to assume that every instruction not
        // handled by one of the cases above could produce *any*
        // value whatsoever.
        return LatticeVal::getAny();
    }


    // For basic blocks, we will do tracking very similar to what we do for
    // ordinary instructions, just with a simpler lattice: every block
    // will either be marked as "never executed" or in a "possibly executed"
    // state. We track this as a set of the blocks that have been
    // marked as possibly executed, plus a getter and setter function.

    HashSet<IRBlock*> executedBlocks;

    bool isMarkedAsExecuted(IRBlock* block)
    {
        return executedBlocks.contains(block);
    }

    void markAsExecuted(IRBlock* block)
    {
        executedBlocks.add(block);
    }

    // The core of the algorithm is based on two work lists.
    // One list holds CFG nodes (basic blocks) that we have
    // discovered might execute, and thus need to be processed,
    // and the other holds SSA nodes (instructions) that need
    // their "estimated" value to be updated.

    List<IRBlock*>  cfgWorkList;
    List<IRInst*>   ssaWorkList;

    // A key operation is to take an IR instruction and update
    // its "estimated" value on the lattice. This might happen when
    // we first discover the instruction could be executed, or
    // when we discover that one or more of its operands has
    // changed its lattice value so that we need to update our estimate.
    //
    void updateValueForInst(IRInst* inst)
    {
        // Block parameters are conceptually SSA "phi nodes", and it
        // doesn't make sense to update their values here, because the
        // actual candidate values for them comes from the predecessor blocks
        // that provide arguments. We will see that logic shortly, when
        // handling `IRUnconditionalBranch`.
        //
        if(as<IRParam, IRDynamicCastBehavior::NoUnwrap>(inst))
            return;

        // We want to special-case terminator instructions here,
        // since abstract interpretation of them should cause blocks to
        // be marked as executed, etc.
        //
        if( const auto terminator = as<IRTerminatorInst>(inst) )
        {
            if( auto unconditionalBranch = as<IRUnconditionalBranch>(inst) )
            {
                // When our abstract interpreter "executes" an unconditional
                // branch, it needs to mark the target block as potentially
                // executed. We do this by adding the target to our CFG work list.
                //
                auto target = unconditionalBranch->getTargetBlock();
                cfgWorkList.add(target);

                // Besides transferring control to another block, the other
                // thing our unconditional branch instructions do is provide
                // the arguments for phi nodes in the target block.
                // We thus need to interpret each argument on the branch
                // instruction like an "assignment" to the corresponding
                // parameter of the target block.
                //
                UInt argCount = unconditionalBranch->getArgCount();
                IRParam* pp = target->getFirstParam();
                for( UInt aa = 0; aa < argCount; ++aa, pp = pp->getNextParam() )
                {
                    IRInst* arg = unconditionalBranch->getArg(aa);
                    IRInst* param = pp;

                    // We expect the number of arguments and parameters to match,
                    // or else the IR is violating its own invariants.
                    //
                    SLANG_ASSERT(param);

                    // We will update the value for the target block's parameter
                    // using our "meet" operation (union of sets of possible values)
                    //
                    LatticeVal oldVal = getLatticeVal(param);

                    // If we've already determined that the block parameter could
                    // have any value whatsoever, there is no reason to bother
                    // updating it.
                    //
                    if(oldVal.flavor == LatticeVal::Flavor::Any)
                        continue;

                    // We can look up the lattice value for the argument,
                    // because we should have interpreted it already
                    //
                    LatticeVal argVal = getLatticeVal(arg);

                    // Now we apply the meet operation and see if the value changed.
                    //
                    LatticeVal newVal = meet(oldVal, argVal);
                    if( newVal != oldVal )
                    {
                        // If the "estimated" value for the parameter has changed,
                        // then we need to update it in our dictionary, and then
                        // make sure that all of the users of the parameter get
                        // their estimates updated as well.
                        //
                        setLatticeVal(param, newVal);
                        for( auto use = param->firstUse; use; use = use->nextUse )
                        {
                            ssaWorkList.add(use->getUser());
                        }
                    }
                }
            }
            else if( auto conditionalBranch = as<IRConditionalBranch>(inst) )
            {
                // An `IRConditionalBranch` is used for two-way branches.
                // We will look at the lattice value for the condition,
                // to see if we can narrow down which of the two ways
                // might actually be taken.
                //
                auto condVal = getLatticeVal(conditionalBranch->getCondition());

                // We do not expect to see a `None` value here, because that
                // would mean the user is branching based on an undefined
                // value.
                //
                // TODO: We should make sure there is no way for the user
                // to trigger this assert with bad code that involves
                // uninitialized variables. Right now we don't special
                // case the `undefined` instruction when computing lattice
                // values, so it shouldn't be a problem.
                //
                SLANG_ASSERT(condVal.flavor != LatticeVal::Flavor::None);

                // If the branch condition is a constant, we expect it to
                // be a Boolean constant. We won't assert that it is the
                // case here, just to be defensive.
                //
                if( condVal.flavor == LatticeVal::Flavor::Constant )
                {
                    if( auto boolConst = as<IRBoolLit>(condVal.value) )
                    {
                        // Only one of the two targe blocks is possible to
                        // execute, based on what we know of the condition,
                        // so we will add that target to our work list and
                        // bail out now.
                        //
                        auto target = boolConst->getValue() ? conditionalBranch->getTrueBlock() : conditionalBranch->getFalseBlock();
                        cfgWorkList.add(target);
                        return;
                    }
                }

                // As a fallback, if the condition isn't constant
                // (or somehow wasn't a Boolean constnat), we will
                // assume that either side of the branch could be
                // taken, so that both of the target blocks are
                // potentially executed.
                //
                cfgWorkList.add(conditionalBranch->getTrueBlock());
                cfgWorkList.add(conditionalBranch->getFalseBlock());
            }
            else if( auto switchInst = as<IRSwitch>(inst) )
            {
                // The handling of a `switch` instruction is similar to the
                // case for a two-way branch, with the main difference that
                // we have to deal with an integer condition value.

                auto condVal = getLatticeVal(switchInst->getCondition());
                SLANG_ASSERT(condVal.flavor != LatticeVal::Flavor::None);

                UInt caseCount = switchInst->getCaseCount();
                if( condVal.flavor == LatticeVal::Flavor::Constant )
                {
                    if( auto condConst = as<IRIntLit>(condVal.value) )
                    {
                        // At this point we have a constant integer condition
                        // value, and we just need to find the case (if any)
                        // that matches it. We will default to considering
                        // the `default` label as the target.
                        //
                        auto target = switchInst->getDefaultLabel();
                        for( UInt cc = 0; cc < caseCount; ++cc )
                        {
                            if( auto caseConst = as<IRIntLit>(switchInst->getCaseValue(cc)) )
                            {
                                if(caseConst->getValue() == condConst->getValue())
                                {
                                    target = switchInst->getCaseLabel(cc);
                                    break;
                                }
                            }
                        }

                        // Whatever single block we decided will get executed,
                        // we need to make sure it gets processed and then bail.
                        //
                        cfgWorkList.add(target);
                        return;
                    }
                }

                // The fallback is to assume that the `switch` instruction might
                // branch to any of its cases, or the `default` label.
                //
                for( UInt cc = 0; cc < caseCount; ++cc )
                {
                    cfgWorkList.add(switchInst->getCaseLabel(cc));
                }
                cfgWorkList.add(switchInst->getDefaultLabel());
            }
            else if (auto targetSwitch = as<IRTargetSwitch>(inst))
            {
                for (UInt cc = 0; cc < targetSwitch->getCaseCount(); ++cc)
                {
                    cfgWorkList.add(targetSwitch->getCaseBlock(cc));
                }
            }
            // There are other cases of terminator instructions not handled
            // above (e.g., `return` instructions), but these can't cause
            // additional basic blocks in the CFG to execute, so we don't
            // need to consider them here.
            //
            // No matter what, we are done with a terminator instruction
            // after inspecting it, and there is no reason we have to
            // try and compute its "value."
            return;
        }

        // For an "ordinary" instruction, we will first check what value
        // has been registered for it already.
        //
        LatticeVal oldVal = getLatticeVal(inst);

        // If we have previous decided that the instruction could take
        // on any value whatsoever, then any further update to our
        // guess can't expand things more, and so there is nothing to do.
        //
        if( oldVal.flavor == LatticeVal::Flavor::Any )
        {
            return;
        }

        // Otherwise, we compute a new guess at the value of
        // the instruction based on the lattice values of the
        // stuff it depends on.
        //
        LatticeVal newVal = interpretOverLattice(inst);

        // If nothing changed about our guess, then there is nothing
        // further to do, because users of this instruction have
        // already computed their guess based on its current value.
        //
        if(newVal == oldVal)
        {
            return;
        }

        // If the guess did change, then we want to register our
        // new guess as the lattice value for this instruction.
        //
        setLatticeVal(inst, newVal);

        // Next we iterate over all the users of this instruction
        // and add them to our work list so that we can update
        // their values based on the new information.
        //
        for( auto use = inst->firstUse; use; use = use->nextUse )
        {
            ssaWorkList.add(use->getUser());
        }
    }

    // Run the constant folding on global scope and specialized types only.
    bool applyOnGlobalScope(IRModule* module)
    {
        bool changed = applyOnScope(module->getModuleInst());
        for (auto child : module->getModuleInst()->getChildren())
        {
            switch (child->getOp())
            {
            case kIROp_StructType:
            case kIROp_ClassType:
            case kIROp_InterfaceType:
            case kIROp_WitnessTable:
                changed |= applyOnScope(child);
                break;
            }
        }
        return changed;
    }

    bool applyOnScope(IRInst* scopeInst)
    {
        builderStorage = IRBuilder(scopeInst);
        for (auto child : scopeInst->getChildren())
        {
            // Only consider evaluable opcodes.
            if (!isEvaluableOpCode(child->getOp()))
                continue;

            updateValueForInst(child);
        }
        while (ssaWorkList.getCount())
        {
            auto inst = ssaWorkList[0];
            ssaWorkList.fastRemoveAt(0);
            // Only consider evaluable opcodes and insts at global scope.
            if (!isEvaluableOpCode(inst->getOp()) || inst->getParent() != scopeInst)
                continue;
            updateValueForInst(inst);
        }

        bool changed = false;
        // Replace the insts with their values.
        List<IRInst*> instsToRemove;
        for (auto child : scopeInst->getChildren())
        {
            if (!isEvaluableOpCode(child->getOp()))
                continue;

            auto latticeVal = getLatticeVal(child);
            if (latticeVal.flavor == LatticeVal::Flavor::Constant && latticeVal.value != child)
            {
                child->replaceUsesWith(latticeVal.value);
                instsToRemove.add(child);
            }
        }

        if (instsToRemove.getCount())
        {
            changed = true;
            for (auto inst : instsToRemove)
                inst->removeAndDeallocate();
        }
        return changed;
    }

    // The `apply()` function will run the full algorithm.
    //
    bool apply()
    {
        bool changed = false;
        // We start with the busy-work of setting up our IR builder.
        //
        builderStorage = IRBuilder(shared->module);

        // We expect the caller to have filtered out functions with
        // no bodies, so there should always be at least one basic block.
        //
        auto firstBlock = code->getFirstBlock();
        SLANG_ASSERT(firstBlock);

        // The entry block is always going to be executed when the
        // function gets called, so we will process it right away.
        //
        cfgWorkList.add(firstBlock);

        // The parameters of the first block are our function parameters,
        // and we want to operate on the assumption that they could have
        // any value possible, so we will record that in our dictionary.
        //
        for( auto pp : firstBlock->getParams() )
        {
            setLatticeVal(pp, LatticeVal::getAny());
        }

        // Now we will iterate until both of our work lists go dry.
        //
        while(cfgWorkList.getCount() || ssaWorkList.getCount())
        {
            // Note: there is a design choice to be had here
            // around whether we do `if if` or `while while`
            // for these nested checks. The choice can affect
            // how long things take to converge.

            // We will start by processing any blocks that we
            // have determined are potentially reachable.
            //
            while( cfgWorkList.getCount() )
            {
                // We pop one block off of the work list.
                //
                auto block = cfgWorkList[0];
                cfgWorkList.fastRemoveAt(0);

                // We only want to process blocks that haven't
                // already been marked as executed, so that we
                // don't do redundant work.
                //
                if( !isMarkedAsExecuted(block) )
                {
                    // We should mark this new block as executed,
                    // so we can ignore it if it ever ends up on
                    // the work list again.
                    //
                    markAsExecuted(block);

                    // If the block is potentially executed, then
                    // that means the instructions in the block are too.
                    // We will walk through the block and update our
                    // guess at the value of each instruction, which
                    // may in turn add other blocks/instructions to
                    // the work lists.
                    //
                    for( auto inst : block->getDecorationsAndChildren() )
                    {
                        updateValueForInst(inst);
                    }
                }
            }

            // Once we've cleared the work list of blocks, we
            // will start looking at individual instructions that
            // need to be updated.
            //
            while( ssaWorkList.getCount() )
            {
                // We pop one instruction that needs an update.
                //
                auto inst = ssaWorkList[0];
                ssaWorkList.fastRemoveAt(0);

                // Before updating the instruction, we will check if
                // the parent block of the instructin is marked as
                // being executed. If it isn't, there is no reason
                // to update the value for the instruction, since
                // it might never be used anyway.
                //
                IRBlock* block = as<IRBlock>(inst->getParent());

                // It is possible that an instruction ended up on
                // our SSA work list because it is a user of an
                // instruction in a block of `code`, but it is not
                // itself an instruction a block of `code`.
                //
                // For example, if `code` is an `IRGeneric` that
                // yields a function, then `inst` might be an
                // instruction of that nested function, and not
                // an instruction of the generic itself.
                // Note that in such a case, the `inst` cannot
                // possible affect the values computed in the outer
                // generic, or the control-flow paths it might take,
                // so there is no reason to consider it.
                //
                // We guard against this case by only processing `inst`
                // if it is a child of a block in the current `code`.
                //
                if(!block || block->getParent() != code)
                    continue;

                if( isMarkedAsExecuted(block) )
                {
                    // If the instruction is potentially executed, we update
                    // its lattice value based on our abstraction interpretation.
                    //
                    updateValueForInst(inst);
                }
            }
        }

        // Once the work lists are empty, our "guesses" at the value
        // of different instructions and the potentially-executed-ness
        // of blocks should have converged to a conservative steady state.
        //
        // We are now equiped to start using the information we've gathered
        // to modify the code.

        // First, we will walk through all the code and replace instructions
        // with constants where it is possible.
        //
        List<IRInst*> instsToRemove;
        for( auto block : code->getBlocks() )
        {
            for( auto inst : block->getDecorationsAndChildren() )
            {
                // We look for instructions that have a constnat value on
                // the lattice.
                //
                LatticeVal latticeVal = getLatticeVal(inst);
                if(latticeVal.flavor != LatticeVal::Flavor::Constant)
                    continue;

                // As a small sanity check, we won't go replacing an
                // instruction with itself (this shouldn't really come
                // up, since constants are supposed to be at the global
                // scope right now)
                //
                IRInst* constantVal = latticeVal.value;
                if(constantVal == inst)
                    continue;

                // We replace any uses of the instruction with its
                // constant expected value, and add it to a list of
                // instructions to be removed *iff* the instruction
                // is known to have no obersvable side effects.
                //
                inst->replaceUsesWith(constantVal);
                if( !inst->mightHaveSideEffects() )
                {
                    // Don't delete phi parameters, they will be cleaned up in CFG simplification.
                    if (inst->getOp() != kIROp_Param)
                        instsToRemove.add(inst);
                }
            }
        }

        if (instsToRemove.getCount() != 0)
            changed = true;

        // Once we've replaced the uses of instructions that evaluate
        // to constants, we make a second pass to remove the instructions
        // themselves (or at least those without side effects).
        //
        for( auto inst : instsToRemove )
        {
            inst->removeAndDeallocate();
        }

        // Next we are going to walk through all of the terminator
        // instructions on blocks and look for ones that branch
        // based on a constant condition. These will be rewritten
        // to use direct branching instructions, which will of course
        // need to be emitted using a builder.
        //
        auto builder = getBuilder();
        for( auto block : code->getBlocks() )
        {
            auto terminator = block->getTerminator();

            // We check if we have a `switch` instruction with a constant
            // integer as its condition.
            //
            if( auto switchInst = as<IRSwitch>(terminator) )
            {
                if( auto constVal = as<IRIntLit>(switchInst->getCondition()) )
                {
                    // We will select the one branch that gets taken, based
                    // on the constant condition value. The `default` label
                    // will of course be taken if no `case` label matches.
                    //
                    IRBlock* target = switchInst->getDefaultLabel();
                    UInt caseCount = switchInst->getCaseCount();
                    for(UInt cc = 0; cc < caseCount; ++cc)
                    {
                        auto caseVal = switchInst->getCaseValue(cc);
                        if(auto caseConst = as<IRIntLit>(caseVal))
                        {
                            if( caseConst->getValue() == constVal->getValue() )
                            {
                                target = switchInst->getCaseLabel(cc);
                                break;
                            }
                        }
                    }

                    // Once we've found the target, we will emit a direct
                    // branch to it before the old terminator, and then remove
                    // the old terminator instruction.
                    //
                    builder->setInsertBefore(terminator);
                    builder->emitBranch(target);
                    terminator->removeAndDeallocate();
                    changed = true;
                }
            }
            else if(auto condBranchInst = as<IRConditionalBranch>(terminator))
            {
                if( auto constVal = as<IRBoolLit>(condBranchInst->getCondition()) )
                {
                    // The case for a two-sided conditional branch is similar
                    // to the `switch` case, but simpler.

                    IRBlock* target = constVal->getValue() ? condBranchInst->getTrueBlock() : condBranchInst->getFalseBlock();

                    builder->setInsertBefore(terminator);
                    builder->emitBranch(target);
                    terminator->removeAndDeallocate();
                    changed = true;
                }
            }
        }

        // At this point we've replaced some conditional branches
        // that would always go the same way (e.g., a `while(true)`),
        // which should render some of our blocks unreachable.
        // We will collect all those unreachable blocks into a list
        // of blocks to be removed, and then go about trying to
        // remove them.
        //
        List<IRBlock*> unreachableBlocks;
        for( auto block : code->getBlocks() )
        {
            if( !isMarkedAsExecuted(block) )
            {
                unreachableBlocks.add(block);
            }
        }
        //
        // It might seem like we could just do:
        //
        //      block->removeAndDeallocate();
        //
        // for each of the blocks in `unreachableBlocks`, but there
        // is a subtle point that has to be considered:
        //
        // We have a structured control-flow representation where
        // certain branching instructions name "join points" where
        // control flow logically re-converges. It is possible that
        // one of our unreachable blocks is still being used as
        // a join point.
        //
        // For example:
        //
        //      if(A)
        //          return B;
        //      else
        //          return C;
        //      D;
        //
        // In the above example, the block that computes `D` is
        // unreachable, but it is still the join point for the `if(A)`
        // branch.
        //
        // Rather than complicate the encoding of join points to
        // try to special-case an unreachable join point, we will
        // instead retain the join point as a block with only a single
        // `unreachable` instruction.
        //
        // To detect which blocks are unreachable and unreferenced,
        // we will check which blocks have any uses. Of course, it
        // might be that some of our unreachable blocks still reference
        // one another (e.g., an unreachable loop) so we will start
        // by removing the instructions from the bodies of our unreachable
        // blocks to eliminate any cross-references between them.
        //
        for( auto block : unreachableBlocks )
        {
            // TODO: In principle we could produce a diagnostic here
            // if any of these unreachable blocks appears to have
            // "non-trivial" code in it (that is, any code explicitly
            // written by the user, and not just code synthesized by
            // the compiler to satisfy language rules). Making that
            // determination could be tricky, so for now we will
            // err on the side of allowing unreachable code without
            // a warning.
            //
            block->removeAndDeallocateAllDecorationsAndChildren();
        }
        //
        // At this point every one of our unreachable blocks is empty,
        // and there should be no branches from reachable blocks
        // to unreachable ones.
        //
        // We will iterate over our unreachable blocks, and process
        // them differently based on whether they have any remaining uses.
        //
        for( auto block : unreachableBlocks )
        {
            // At this point there had better be no edges branching to
            // our block. We determined it was unreachable, so there had
            // better not be branches from reachable blocks to this one,
            // and all the unreachable blocks had their instructions
            // removed, so there should be no branches to it from other
            // unreachable blocks (or itself).
            //
            SLANG_ASSERT(block->getPredecessors().isEmpty());

            // If the block is completely unreferenced, we can safely
            // remove and deallocate it now.
            //
            if( !block->hasUses() )
            {
                block->removeAndDeallocate();
            }
            else
            {
                // Otherwise, the block has at least one use (but
                // no predecessors), which should indicate that it
                // is an unreachable join point.
                //
                // We will keep the block around, but its entire
                // body will consist of a single `unreachable`
                // instruction.
                //
                builder->setInsertInto(block);
                builder->emitUnreachable();
            }
        }
        return changed;
    }
};

static bool applySparseConditionalConstantPropagationRec(
    const SCCPContext&  globalContext,
    IRInst*             inst)
{
    bool changed = false;
    if( auto code = as<IRGlobalValueWithCode>(inst) )
    {
        if( code->getFirstBlock() )
        {
            SCCPContext context;
            context.shared = globalContext.shared;
            context.code = code;
            context.mapInstToLatticeVal = globalContext.mapInstToLatticeVal;
            changed |= context.apply();
        }
    }

    for( auto childInst : inst->getDecorationsAndChildren() )
    {
        switch (childInst->getOp())
        {
        case kIROp_Func:
        case kIROp_Block:
        case kIROp_Generic:
            break;
        default:
            // Skip other op codes.
            continue;
        }
        changed |= applySparseConditionalConstantPropagationRec(globalContext, childInst);
    }
    return changed;
}

bool applySparseConditionalConstantPropagation(
    IRModule*       module,
    DiagnosticSink* sink)
{
    if (sink && sink->getErrorCount())
        return false;

    SharedSCCPContext shared;
    shared.module = module;
    shared.sink = sink;

    // First we fold constants at global scope.
    SCCPContext globalContext;
    globalContext.shared = &shared;
    globalContext.code = nullptr;
    bool changed = globalContext.applyOnGlobalScope(module);

    // Now run recursive SCCP passes on each child code block.
    changed |= applySparseConditionalConstantPropagationRec(globalContext, module->getModuleInst());

    return changed;
}

bool applySparseConditionalConstantPropagationForGlobalScope(
    IRModule* module,
    DiagnosticSink* sink)
{
    if (sink && sink->getErrorCount())
        return false;

    SharedSCCPContext shared;
    shared.module = module;
    shared.sink = sink;
    SCCPContext globalContext;
    globalContext.shared = &shared;
    globalContext.code = nullptr;
    bool changed = globalContext.applyOnGlobalScope(module);
    return changed;
}

bool applySparseConditionalConstantPropagation(IRInst* func, DiagnosticSink* sink)
{
    if (sink && sink->getErrorCount())
        return false;

    SharedSCCPContext shared;
    shared.module = func->getModule();
    shared.sink = sink;

    SCCPContext globalContext;
    globalContext.shared = &shared;
    globalContext.code = nullptr;

    // Run recursive SCCP passes on each child code block.
    return applySparseConditionalConstantPropagationRec(globalContext, func);
}

IRInst* tryConstantFoldInst(IRModule* module, IRInst* inst)
{
    SharedSCCPContext shared;
    shared.module = module;
    SCCPContext instContext;
    instContext.shared = &shared;
    instContext.code = nullptr;
    instContext.builderStorage = IRBuilder(module);
    auto foldResult = instContext.interpretOverLattice(inst);
    if (!foldResult.value)
    {
        return inst;
    }
    inst->replaceUsesWith(foldResult.value);
    return foldResult.value;
}

}

back to top