https://github.com/shader-slang/slang
Raw File
Tip revision: 939be44ca23476e622dfb24a592383fe2a1da61f authored by Yong He on 26 October 2022, 15:32:24 UTC
Auto synthesis of Differential type (#2466)
Tip revision: 939be44
slang-ir-constexpr.cpp
// slang-ir-constexpr.cpp
#include "slang-ir-constexpr.h"

#include "slang-ir.h"
#include "slang-ir-insts.h"

namespace Slang {

struct PropagateConstExprContext
{
    IRModule* module;
    IRModule* getModule() { return module; }

    DiagnosticSink* sink;

    SharedIRBuilder sharedBuilder;
    IRBuilder builder;

    List<IRInst*> workList;
    HashSet<IRInst*> onWorkList;

    IRBuilder* getBuilder() { return &builder; }

    Session* getSession() { return sharedBuilder.getSession(); }

    DiagnosticSink* getSink() { return sink; }
};

bool isConstExpr(IRType* fullType)
{
    if( auto rateQualifiedType = as<IRRateQualifiedType>(fullType))
    {
        auto rate = rateQualifiedType->getRate();
        if(auto constExprRate = as<IRConstExprRate>(rate))
            return true;
    }

    return false;
}

bool isConstExpr(IRInst* value)
{
    // Certain IR value ops are implicitly `constexpr`
    //
    // TODO: should we just go ahead and make that explicit
    // in the type system?
    switch(value->getOp())
    {
    case kIROp_IntLit:
    case kIROp_FloatLit:
    case kIROp_BoolLit:
    case kIROp_Func:
        return true;

    default:
        break;
    }

    if(isConstExpr(value->getFullType()))
        return true;

    return false;
}

bool opCanBeConstExpr(IROp op)
{
    switch( op )
    {
    case kIROp_IntLit:
    case kIROp_FloatLit:
    case kIROp_BoolLit:
    case kIROp_Add:
    case kIROp_Sub:
    case kIROp_Mul:
    case kIROp_Div:
    case kIROp_IRem:
    case kIROp_FRem:
    case kIROp_Neg:
    case kIROp_Construct:
    case kIROp_makeVector:
    case kIROp_makeArray:
    case kIROp_MakeMatrix:
    // TODO: more cases
        return true;

    default:
        return false;
    }
}

bool opCanBeConstExpr(IRInst* value)
{
    // TODO: realistically need to special-case `call`
    // operations here, so that we check whether the
    // callee function is fixed/known, and if it is
    // whether it has been decoared as constant-foldable

    return opCanBeConstExpr(value->getOp());
}

void markConstExpr(
    PropagateConstExprContext*  context,
    IRInst*                    value)
{
    Slang::markConstExpr(context->getBuilder(), value);
}


// Propagate `constexpr`-ness in a forward direction, from the
// operands of an instruction to the instruction itself.
bool propagateConstExprForward(
    PropagateConstExprContext*  context,
    IRGlobalValueWithCode*      code)
{
    bool anyChanges = false;
    for(;;)
    {
        bool changedThisIteration = false;
        for( auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock() )
        {
            for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() )
            {
                // Instruction already `constexpr`? Then skip it.
                if(isConstExpr(ii))
                    continue;

                // Is the operation one that we can actually make be constexpr?
                if(!opCanBeConstExpr(ii))
                    continue;

                // Are all arguments `constexpr`?
                bool allArgsConstExpr = true;
                UInt argCount = ii->getOperandCount();
                for( UInt aa = 0; aa < argCount; ++aa )
                {
                    auto arg = ii->getOperand(aa);

                    if( !isConstExpr(arg) )
                    {
                        allArgsConstExpr = false;
                        break;
                    }
                }
                if(!allArgsConstExpr)
                    continue;

                // Seems like this operation can/should be made constexpr
                markConstExpr(context, ii);
                changedThisIteration = true;
            }
        }

        if( !changedThisIteration )
            return anyChanges;

        anyChanges = true;
    }
}

void maybeAddToWorkList(
    PropagateConstExprContext*  context,
    IRInst*                     gv)
{
    if( !context->onWorkList.Contains(gv) )
    {
        context->workList.add(gv);
        context->onWorkList.Add(gv);
    }
}

bool maybeMarkConstExpr(
    PropagateConstExprContext*  context,
    IRInst*                    value)
{
    if(isConstExpr(value))
        return false;

    if(!opCanBeConstExpr(value))
        return false;

    markConstExpr(context, value);

    // TODO: we should only allow function parameters to be
    // changed to be `constexpr` when we are compiling "application"
    // code, and not library code.
    // (Or eventually we'd have a rule that only non-`public` symbols
    // can have this kind of propagation applied).

    if(value->getOp() == kIROp_Param)
    {
        auto param = (IRParam*) value;
        auto block = (IRBlock*) param->parent;
        auto code = block->getParent();

        if(block == code->getFirstBlock())
        {
            // We've just changed a function parameter to
            // be `constexpr`. We need to remember that
            // fact so taht we can mark callers of this
            // function as `constexpr` themselves.

            for( auto u = code->firstUse; u; u = u->nextUse )
            {
                auto user = u->getUser();

                switch( user->getOp() )
                {
                case kIROp_Call:
                    {
                        auto inst = (IRCall*) user;
                        auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
                        maybeAddToWorkList(context, caller);
                    }
                    break;

                default:
                    break;
                }
            }
        }
    }

    return true;
}

// Propagate `constexpr`-ness in a backward direction, from an instruction
// to its operands.
bool propagateConstExprBackward(
    PropagateConstExprContext*  context,
    IRGlobalValueWithCode*      code)
{
    SharedIRBuilder sharedBuilder(context->getModule());

    IRBuilder builder(sharedBuilder);
    builder.setInsertInto(code);

    bool anyChanges = false;
    for(;;)
    {
        // Note: we are walking the list of blocks and the instructions
        // in each block in reverse order, to maximize the chances that
        // we propagate multiple changes in a each pass.
        //
        // TODO: this should probably all be done with a work list instead,
        // but that requires being able to detect instructions vs. other
        // values.

        bool changedThisIteration = false;
        for( auto bb = code->getLastBlock(); bb; bb = bb->getPrevBlock() )
        {
            for( auto ii = bb->getLastInst(); ii; ii = ii->getPrevInst() )
            {
                if( isConstExpr(ii) )
                {
                    // If this instruction is `constexpr`, then its operands should be too.
                    UInt argCount = ii->getOperandCount();
                    for( UInt aa = 0; aa < argCount; ++aa )
                    {
                        auto arg = ii->getOperand(aa);
                        if(isConstExpr(arg))
                            continue;

                        if(!opCanBeConstExpr(arg))
                            continue;

                        if( maybeMarkConstExpr(context, arg) )
                        {
                            changedThisIteration = true;
                        }
                    }
                }
                else if( ii->getOp() == kIROp_Call )
                {
                    // A non-constexpr call might be calling a function with one or
                    // more constexpr parameters. We should check if we can resolve
                    // the callee for this call statically, and if so try to propagate
                    // constexpr from the parameters back to the arguments.
                    auto callInst = (IRCall*) ii;

                    UInt operandCount = callInst->getOperandCount();

                    UInt firstCallArg = 1;
                    UInt callArgCount = operandCount - firstCallArg;

                    auto callee = callInst->getOperand(0);

                    // If we are calling a generic operation, then
                    // try to follow through the `specialize` chain
                    // and find the callee.
                    //
                    // TODO: This probably shouldn't be required,
                    // since we can hopefully use the type of the
                    // callee in all cases.
                    //
                    while(auto specInst = as<IRSpecialize>(callee))
                    {
                        auto genericInst = as<IRGeneric>(specInst->getBase());
                        if(!genericInst)
                            break;

                        auto returnVal = findGenericReturnVal(genericInst);
                        if(!returnVal)
                            break;

                        callee = returnVal;
                    }

                    auto calleeFunc = as<IRFunc>(callee);
                    if(calleeFunc && isDefinition(calleeFunc))
                    {
                        // We have an IR-level function definition we are calling,
                        // and thus we can propagate `constexpr` information
                        // through its `IRParam`s.

                        auto calleeFuncType = calleeFunc->getDataType();

                        UInt callParamCount = calleeFuncType->getParamCount();
                        SLANG_RELEASE_ASSERT(callParamCount == callArgCount);

                        // If the callee has a definition, then we can read `constexpr`
                        // information off of the parameters of its first IR block.
                        if(auto calleeFirstBlock = calleeFunc->getFirstBlock())
                        {
                            UInt paramCounter = 0;
                            for(auto pp = calleeFirstBlock->getFirstParam(); pp; pp = pp->getNextParam())
                            {
                                UInt paramIndex = paramCounter++;

                                auto param = pp;
                                auto arg = callInst->getOperand(firstCallArg + paramIndex);

                                if(isConstExpr(param))
                                {
                                    if(maybeMarkConstExpr(context, arg))
                                    {
                                        changedThisIteration = true;
                                    }
                                }
                            }
                        }
                    }
                    else
                    {
                        // If we don't have a concrete callee function
                        // definition, then we need to extract the
                        // type of the callee instruction, and try to work
                        // with that.
                        //
                        // Note that this does not allow us to propagate
                        // `constexpr` information from the body of a callee
                        // back to call sites.
                        auto calleeType = callee->getDataType();
                        if(auto caleeFuncType = as<IRFuncType>(calleeType))
                        {
                            auto paramCount = caleeFuncType->getParamCount();
                            for( UInt pp = 0; pp < paramCount; ++pp )
                            {
                                auto paramType = caleeFuncType->getParamType(pp);
                                auto arg = callInst->getOperand(firstCallArg + pp);
                                if( isConstExpr(paramType) )
                                {
                                    if( maybeMarkConstExpr(context, arg) )
                                    {
                                        changedThisIteration = true;
                                    }
                                }
                            }
                        }
                    }
                }
            }

            if( bb != code->getFirstBlock() )
            {
                // A parameter in anything butr the first block is
                // conceptually a phi node, which means its operands
                // are the corresponding values from the terminating
                // branch in a predecessor block.

                UInt paramCounter = 0;
                for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
                {
                    UInt paramIndex = paramCounter++;

                    if(!isConstExpr(pp))
                        continue;

                    for(auto pred : bb->getPredecessors())
                    {
                        auto terminator = pred->getLastInst();
                        if(terminator->getOp() != kIROp_unconditionalBranch)
                            continue;

                        UInt operandIndex = paramIndex + 1;
                        SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount());

                        auto operand = terminator->getOperand(operandIndex);
                        if( maybeMarkConstExpr(context, operand) )
                        {
                            changedThisIteration = true;
                        }
                    }
                }
            }

        }

        if( !changedThisIteration )
            return anyChanges;

        anyChanges = true;
    }
}

// Validate use of `constexpr` within a function (in particular,
// diagnose places where a value that must be contexpr depends
// on a value that cannot be)
void validateConstExpr(
    PropagateConstExprContext*  context,
    IRGlobalValueWithCode*      code)
{
    for( auto bb = code->getFirstBlock(); bb; bb = bb->getNextBlock() )
    {
        for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() )
        {
            if(isConstExpr(ii))
            {
                // For an instruction that must be `constexpr`, we need
                // to ensure that its argumenst are all `constexpr`

                UInt argCount = ii->getOperandCount();
                for( UInt aa = 0; aa < argCount; ++aa )
                {
                    auto arg = ii->getOperand(aa);

                    if( !isConstExpr(arg) )
                    {
                        // Diagnose the failure.

                        context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant);

                        break;
                    }
                }
            }
        }
    }
}

void propagateConstExpr(
    IRModule*       module,
    DiagnosticSink* sink)
{
    PropagateConstExprContext context;
    context.module = module;
    context.sink = sink;
    context.sharedBuilder.init(module);
    context.builder.init(context.sharedBuilder);

    // We need to propagate information both forward and backward.
    //
    // In the forward direction we need to check if all of the operands
    // to an instruction are `constexpr` *and* if the operation is
    // one that can conceptually be "promoted" to the constexpr rate.
    //
    // In the backward direction, if an instruction has already been
    // marked as needing to be `constexpr`, then its operands had
    // better be too.
    //
    // The backward direction needs to be interprocedural, because
    // a parameter to a function might be `constexpr`, so that callers
    // of that function would need to be marked too. If backwards
    // propagation in any of the callers leads to some of their
    // parameters being marked constexpr, then we would need to
    // revisit their callers.

    // We will build an initial work list with all of the global values in it.
    
    for( auto ii : module->getGlobalInsts() )
    {
        maybeAddToWorkList(&context, ii);
    }

    // We will iterate applying propagation to one global value at a time
    // until we run out.
    while( context.workList.getCount() )
    {
        auto gv = context.workList[0];
        context.workList.fastRemoveAt(0);
        context.onWorkList.Remove(gv);

        switch( gv->getOp() )
        {
        default:
            break;

        case kIROp_Func:
        case kIROp_GlobalVar:
            {
                IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv;

                for( ;;)
                {
                    bool anyChange = false;
                    if( propagateConstExprForward(&context, code) )
                    {
                        anyChange = true;
                    }
                    if( propagateConstExprBackward(&context, code) )
                    {
                        anyChange = true;
                    }
                    if(!anyChange)
                        break;
                }
            }
            break;
        }
    }

    // Okay, we've processed all our functions and found a steady state.
    // Now we need to try and issue diagnostics for any IR values where
    // we find that they are *required* to be `constexpr`, but *cannot*
    // be, for some reason.

    for(auto ii : module->getGlobalInsts())
    {
        switch( ii->getOp() )
        {
        default:
            break;

        case kIROp_Func:
        case kIROp_GlobalVar:
            {
                IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) ii;
                validateConstExpr(&context, code);
            }
            break;
        }
    }

}

}
back to top