https://github.com/shader-slang/slang
Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ir-autodiff-pairs.cpp
#include "slang-ir-autodiff-pairs.h"

namespace Slang
{

struct DiffPairLoweringPass : InstPassBase
{
    DiffPairLoweringPass(AutoDiffSharedContext* context) :
        InstPassBase(context->moduleInst->getModule()), 
        pairBuilderStorage(context)
    { 
        pairBuilder = &pairBuilderStorage;
    }

    IRInst* lowerPairType(IRBuilder* builder, IRType* pairType)
    {
        builder->setInsertBefore(pairType);
        auto loweredPairType = pairBuilder->lowerDiffPairType(
            builder,
            pairType);
        return loweredPairType;
    }

    IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
    {
        if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst))
        {
            bool isTrivial = false;
            auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType());
            if (auto loweredPairType = lowerPairType(builder, pairType))
            {
                builder->setInsertBefore(makePairInst);
                IRInst* result = nullptr;
                if (isTrivial)
                {
                    result = makePairInst->getPrimalValue();
                }
                else
                {
                    IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() };
                    result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
                }
                makePairInst->replaceUsesWith(result);
                makePairInst->removeAndDeallocate();
                return result;
            }
        }

        return nullptr;
    }

    IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
    {
        if (auto getDiffInst = as<IRDifferentialPairGetDifferentialBase>(inst))
        {
            auto pairType = getDiffInst->getBase()->getDataType();
            if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
            {
                pairType = pairPtrType->getValueType();
            }

            if (lowerPairType(builder, pairType))
            {
                builder->setInsertBefore(getDiffInst);
                IRInst* diffFieldExtract = nullptr;
                diffFieldExtract = pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase());
                getDiffInst->replaceUsesWith(diffFieldExtract);
                getDiffInst->removeAndDeallocate();
                return diffFieldExtract;
            }
        }
        else if (auto getPrimalInst = as<IRDifferentialPairGetPrimalBase>(inst))
        {
            auto pairType = getPrimalInst->getBase()->getDataType();
            if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
            {
                pairType = pairPtrType->getValueType();
            }

            if (lowerPairType(builder, pairType))
            {
                builder->setInsertBefore(getPrimalInst);

                IRInst* primalFieldExtract = nullptr;
                primalFieldExtract = pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
                getPrimalInst->replaceUsesWith(primalFieldExtract);
                getPrimalInst->removeAndDeallocate();
                return primalFieldExtract;
            }
        }

        return nullptr;
    }

    bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
    {
        bool modified = false;

        processAllInsts([&](IRInst* inst)
            {
                // Make sure the builder is at the right level.
                builder->setInsertInto(instWithChildren);

                switch (inst->getOp())
                {
                case kIROp_DifferentialPairGetDifferential:
                case kIROp_DifferentialPairGetPrimal:
                case kIROp_DifferentialPairGetDifferentialUserCode:
                case kIROp_DifferentialPairGetPrimalUserCode:
                    lowerPairAccess(builder, inst);
                    break;

                case kIROp_MakeDifferentialPairUserCode:
                    lowerMakePair(builder, inst);
                    break;

                default:
                    break;
                }
            });

        OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
        processAllInsts([&](IRInst* inst)
            {
                if (auto pairType = as<IRDifferentialPairTypeBase>(inst))
                {
                    if (auto loweredType = lowerPairType(builder, pairType))
                    {
                        pendingReplacements.add(pairType, loweredType);
                        modified = true;
                    }
                }
            });
        for (auto replacement : pendingReplacements)
        {
            replacement.key->replaceUsesWith(replacement.value);
            replacement.key->removeAndDeallocate();
        }

        return modified;
    }

    bool processModule()
    {
        IRBuilder builder(module);
        return processInstWithChildren(&builder, module->getModuleInst());
    }

    private: 
    
    DifferentialPairTypeBuilder* pairBuilder;

    DifferentialPairTypeBuilder pairBuilderStorage;

};

bool processPairTypes(AutoDiffSharedContext* context)
{
    DiffPairLoweringPass pairLoweringPass(context);
    return pairLoweringPass.processModule();
}

struct DifferentialPairUserCodeTranscribePass : public InstPassBase
{
    DifferentialPairUserCodeTranscribePass(IRModule* module)
        :InstPassBase(module)
    {}

    IRInst* rewritePairType(IRBuilder* builder, IRType* pairType)
    {
        builder->setInsertBefore(pairType);
        auto originalPairType = as<IRDifferentialPairType>(pairType);
        return builder->getDifferentialPairUserCodeType(originalPairType->getValueType(), originalPairType->getWitness());
    }

    IRInst* rewriteMakePair(IRBuilder* builder, IRMakeDifferentialPair* inst)
    {
        auto pairType = as<IRDifferentialPairType>(inst->getFullType());
        builder->setInsertBefore(inst);
        auto newInst = builder->emitMakeDifferentialPairUserCode(
            (IRType*)pairType, inst->getPrimalValue(), inst->getDifferentialValue());
        inst->replaceUsesWith(newInst);
        inst->removeAndDeallocate();
        return newInst;
    }

    IRInst* rewritePairAccess(IRBuilder* builder, IRInst* inst)
    {
        if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
        {
            builder->setInsertBefore(inst);

            auto newInst = builder->emitDifferentialPairGetDifferentialUserCode(
                (IRType*)inst->getFullType(), getDiffInst->getBase());
            inst->replaceUsesWith(newInst);
            inst->removeAndDeallocate();
        }
        else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
        {
            builder->setInsertBefore(inst);
            auto newInst = builder->emitDifferentialPairGetPrimalUserCode(getPrimalInst->getBase());
            inst->replaceUsesWith(newInst);
            inst->removeAndDeallocate();
        }
        return inst;
    }

    bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
    {
        SLANG_UNUSED(instWithChildren);

        bool modified = false;

        processAllInsts([&](IRInst* inst)
            {
                switch (inst->getOp())
                {
                case kIROp_DifferentialPairGetDifferential:
                case kIROp_DifferentialPairGetPrimal:
                    rewritePairAccess(builder, inst);
                    break;

                case kIROp_MakeDifferentialPair:
                    rewriteMakePair(builder, as<IRMakeDifferentialPair>(inst));
                    break;

                default:
                    break;
                }
            });

        OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
        processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
            {
                if (auto loweredType = rewritePairType(builder, inst))
                {
                    pendingReplacements.add(inst, loweredType);
                    modified = true;
                }
            });
        for (auto replacement : pendingReplacements)
        {
            replacement.key->replaceUsesWith(replacement.value);
            replacement.key->removeAndDeallocate();
        }

        return modified;
    }

    bool processModule()
    {
        IRBuilder builder(module);
        return processInstWithChildren(&builder, module->getModuleInst());
    }
};

void rewriteDifferentialPairToUserCode(IRModule* module)
{
    DifferentialPairUserCodeTranscribePass pairRewritePass(module);
    pairRewritePass.processModule();
}

}
back to top