Raw File
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
// slang-ast-builder.cpp
#include "slang-ast-builder.h"
#include <assert.h>

#include "slang-compiler.h"

namespace Slang {

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SharedASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


// Set up this shared builder for `session`.
//
// Stores the session and its name pool, allocates the inner `ASTBuilder`,
// zeroes the built-in type table, and fills the two class-info lookup maps
// (one keyed by string slice, one keyed by `Name*`).
void SharedASTBuilder::init(Session* session)
{
    // Keep the session and borrow its name pool.
    m_namePool = session->getNamePool();
    m_session = session;

    // The inner builder is the place where allocations of shared types live.
    {
        RefPtr<ASTBuilder> innerBuilder(new ASTBuilder);
        innerBuilder->m_sharedASTBuilder = this;
        m_astBuilder = innerBuilder.detach();
    }

    // Start with every built-in type slot cleared.
    memset(m_builtinTypes, 0, sizeof(m_builtinTypes));

    // Walk every node type and register whatever class info exists for it.
    // NOTE: this registers the names of the abstract classes as well.
    const Index nodeTypeCount = Index(ASTNodeType::CountOf);
    for (Index nodeTypeIndex = 0; nodeTypeIndex < nodeTypeCount; ++nodeTypeIndex)
    {
        const ReflectClassInfo* classInfo = ASTClassInfo::getInfo(ASTNodeType(nodeTypeIndex));
        if (!classInfo)
            continue;

        m_sliceToTypeMap.add(UnownedStringSlice(classInfo->m_name), classInfo);
        Name* className = m_namePool->getName(String(classInfo->m_name));
        m_nameToTypeMap.add(className, classInfo);
    }
}

// Look up class info by class name in the slice-keyed map.
// Returns nullptr when the map has no entry for `slice`.
const ReflectClassInfo* SharedASTBuilder::findClassInfo(const UnownedStringSlice& slice)
{
    const ReflectClassInfo* found = nullptr;
    if (m_sliceToTypeMap.tryGetValue(slice, found))
        return found;
    return nullptr;
}

// Look up a syntax class by class name in the slice-keyed map.
// Returns a default-constructed SyntaxClass when there is no entry.
SyntaxClass<NodeBase> SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice)
{
    const ReflectClassInfo* found = nullptr;
    const bool isKnown = m_sliceToTypeMap.tryGetValue(slice, found);
    if (!isKnown)
        return SyntaxClass<NodeBase>();
    return SyntaxClass<NodeBase>(found);
}

// Look up class info by interned name in the Name-keyed map.
// Returns nullptr when the map has no entry for `name`.
const ReflectClassInfo* SharedASTBuilder::findClassInfo(Name* name)
{
    const ReflectClassInfo* found = nullptr;
    if (m_nameToTypeMap.tryGetValue(name, found))
        return found;
    return nullptr;
}

// Look up a syntax class by interned name in the Name-keyed map.
// Returns a default-constructed SyntaxClass when there is no entry.
SyntaxClass<NodeBase> SharedASTBuilder::findSyntaxClass(Name* name)
{
    const ReflectClassInfo* found = nullptr;
    const bool isKnown = m_nameToTypeMap.tryGetValue(name, found);
    if (!isKnown)
        return SyntaxClass<NodeBase>();
    return SyntaxClass<NodeBase>(found);
}

// Return the type for the magic declaration "StringType".
// The type is created on the first call and cached in `m_stringType`.
Type* SharedASTBuilder::getStringType()
{
    if (m_stringType)
        return m_stringType;

    Decl* magicDecl = findMagicDecl("StringType");
    m_stringType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_stringType;
}

// Return the type for the magic declaration "NativeStringType".
// The type is created on the first call and cached in `m_nativeStringType`.
Type* SharedASTBuilder::getNativeStringType()
{
    if (m_nativeStringType)
        return m_nativeStringType;

    Decl* magicDecl = findMagicDecl("NativeStringType");
    m_nativeStringType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_nativeStringType;
}

// Return the type for the magic declaration "EnumTypeType".
// The type is created on the first call and cached in `m_enumTypeType`.
Type* SharedASTBuilder::getEnumTypeType()
{
    if (m_enumTypeType)
        return m_enumTypeType;

    Decl* magicDecl = findMagicDecl("EnumTypeType");
    m_enumTypeType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_enumTypeType;
}

// Return the type for the magic declaration "DynamicType".
// The type is created on the first call and cached in `m_dynamicType`.
Type* SharedASTBuilder::getDynamicType()
{
    if (m_dynamicType)
        return m_dynamicType;

    Decl* magicDecl = findMagicDecl("DynamicType");
    m_dynamicType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_dynamicType;
}

// Return the type for the magic declaration "NullPtrType".
// The type is created on the first call and cached in `m_nullPtrType`.
Type* SharedASTBuilder::getNullPtrType()
{
    if (m_nullPtrType)
        return m_nullPtrType;

    Decl* magicDecl = findMagicDecl("NullPtrType");
    m_nullPtrType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_nullPtrType;
}

// Return the type for the magic declaration "NoneType".
// The type is created on the first call and cached in `m_noneType`.
Type* SharedASTBuilder::getNoneType()
{
    if (m_noneType)
        return m_noneType;

    Decl* magicDecl = findMagicDecl("NoneType");
    m_noneType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_noneType;
}

// Return the type for the magic declaration "DifferentiableType".
// The type is created on the first call and cached in `m_diffInterfaceType`.
Type* SharedASTBuilder::getDiffInterfaceType()
{
    if (m_diffInterfaceType)
        return m_diffInterfaceType;

    Decl* magicDecl = findMagicDecl("DifferentiableType");
    m_diffInterfaceType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(magicDecl));
    return m_diffInterfaceType;
}

// Return the ErrorType instance, obtaining it from the inner builder on the
// first call and caching it in `m_errorType`.
Type* SharedASTBuilder::getErrorType()
{
    if (m_errorType)
        return m_errorType;

    m_errorType = m_astBuilder->getOrCreate<ErrorType>();
    return m_errorType;
}
// Return the BottomType instance, obtaining it from the inner builder on the
// first call and caching it in `m_bottomType`.
Type* SharedASTBuilder::getBottomType()
{
    if (m_bottomType)
        return m_bottomType;

    m_bottomType = m_astBuilder->getOrCreate<BottomType>();
    return m_bottomType;
}
// Return the InitializerListType instance, obtaining it from the inner
// builder on the first call and caching it in `m_initializerListType`.
Type* SharedASTBuilder::getInitializerListType()
{
    if (m_initializerListType)
        return m_initializerListType;

    m_initializerListType = m_astBuilder->getOrCreate<InitializerListType>();
    return m_initializerListType;
}
// Return the OverloadGroupType instance, obtaining it from the inner builder
// on the first call and caching it in `m_overloadedType`.
Type* SharedASTBuilder::getOverloadedType()
{
    if (m_overloadedType)
        return m_overloadedType;

    m_overloadedType = m_astBuilder->getOrCreate<OverloadGroupType>();
    return m_overloadedType;
}

    // NOTE(review): the signature that owns these statements is not visible
    // in this view; presumably this is SharedASTBuilder teardown - confirm.

    // Release the built-in types by resetting every slot of the table.
    for (Index i = 0; i < SLANG_COUNT_OF(m_builtinTypes); ++i)
        m_builtinTypes[i] = nullptr;

    // NOTE(review): the statement guarded by this check is not visible here.
    if (m_astBuilder)

// Build a DeclRefType for `decl` and store it in the built-in type table at
// the slot selected by `modifier->tag`.
void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier)
{
    auto builtinType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(decl));
    const Index slot = Index(modifier->tag);
    m_builtinTypes[slot] = builtinType;
}

// Record `decl` as the built-in requirement declaration for `modifier->kind`.
void SharedASTBuilder::registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier)
{
    m_builtinRequirementDecls[modifier->kind] = decl;
}

// Register a declaration under `modifier->magicName` so it can later be
// retrieved with findMagicDecl / tryFindMagicDecl.
void SharedASTBuilder::registerMagicDecl(Decl* decl, MagicTypeModifier* modifier)
{
    // The modifier is sometimes attached to the "inner" declaration of a
    // `GenericDecl`; in that case the generic itself is what gets registered.
    Decl* target = decl;
    auto enclosingGeneric = as<GenericDecl>(decl->parentDecl);
    if (enclosingGeneric)
    {
        target = enclosingGeneric;
    }

    m_magicDecls[modifier->magicName] = target;
}

// Return the declaration registered under the magic name `name`, using the
// map's `getValue` lookup.
// NOTE(review): behavior for an unregistered name depends on `getValue`,
// which is not visible here; use tryFindMagicDecl when absence is expected.
Decl* SharedASTBuilder::findMagicDecl(const String& name)
{
    Decl* registered = m_magicDecls.getValue(name);
    return registered;
}

// Return the declaration registered under the magic name `name`, or nullptr
// when nothing has been registered under that name.
Decl* SharedASTBuilder::tryFindMagicDecl(const String& name)
{
    if (auto entry = m_magicDecls.tryGetValue(name))
        return *entry;
    return nullptr;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// NOTE(review): this span is incomplete in this view. The member-initializer
// list after the ':' and the enclosing braces are not visible, and the
// statements below may belong to more than one definition (for example a
// second constructor and a destructor) - confirm against the full file.
// No visible statement uses the `name` parameter.
ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name):
    // Copy the Val deduplication map from the shared builder's inner builder
    // so that we don't create duplicates of Vals that already exist in the
    // stdlib.
    m_cachedNodes = sharedASTBuilder->getInnerASTBuilder()->m_cachedNodes;

    // Fixed name identifying the builder as the shared builder's inner one.
    m_name = "SharedASTBuilder::m_astBuilder";

    // Visit every node recorded in `m_dtorNodes` and fetch its class info.
    // NOTE(review): whatever is done with `info` is not visible here.
    for (NodeBase* node : m_dtorNodes)
        const ReflectClassInfo* info = ASTClassInfo::getInfo(node->astNodeType);

// Return the epoch id stored on the session that owns the shared builder.
Index ASTBuilder::getEpoch()
{
    auto owningSession = getSharedASTBuilder()->m_session;
    return owningSession->m_epochId;
}

// NOTE(review): no body for this definition is visible in this view;
// presumably it advances the epoch id read by getEpoch() - confirm.
void ASTBuilder::incrementEpoch()

// Instantiate a node of the given AST node type through the create function
// recorded in its class info.
//
// Returns the new node, or nullptr when `nodeType` has no class info or its
// class info has no create function.
NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
{
    const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType);

    // getInfo() can yield null (SharedASTBuilder::init checks for that), so
    // guard before reading the create function.
    if (!info)
    {
        return nullptr;
    }

    auto createFunc = info->m_createFunc;
    if (!createFunc)
    {
        return nullptr;
    }

    return (NodeBase*)createFunc(this);
}

// Build the type that refers to the magic declaration `magicTypeName`,
// passing `typeParam` as its single generic argument.
Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName)
{
    auto specializedDeclRef = getBuiltinDeclRef(magicTypeName, typeParam);
    return DeclRefType::create(this, specializedDeclRef);
}

// Build the type that refers to the magic declaration `magicTypeName`,
// passing `genericArgs` as its generic arguments.
Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const char* magicTypeName)
{
    auto specializedDeclRef = getBuiltinDeclRef(magicTypeName, genericArgs);
    return DeclRefType::create(this, specializedDeclRef);
}

// Construct the "PtrType" specialization for `valueType`.
PtrType* ASTBuilder::getPtrType(Type* valueType)
{
    PtrTypeBase* basePtrType = getPtrType(valueType, "PtrType");
    return dynamicCast<PtrType>(basePtrType);
}

// Construct the type `Out<valueType>` (the "OutType" specialization).
OutType* ASTBuilder::getOutType(Type* valueType)
{
    PtrTypeBase* basePtrType = getPtrType(valueType, "OutType");
    return dynamicCast<OutType>(basePtrType);
}

// Construct the "InOutType" specialization for `valueType`.
InOutType* ASTBuilder::getInOutType(Type* valueType)
{
    PtrTypeBase* basePtrType = getPtrType(valueType, "InOutType");
    return dynamicCast<InOutType>(basePtrType);
}

// Construct the "RefType" specialization for `valueType`.
RefType* ASTBuilder::getRefType(Type* valueType)
{
    PtrTypeBase* basePtrType = getPtrType(valueType, "RefType");
    return dynamicCast<RefType>(basePtrType);
}

// Construct the "ConstRefType" specialization for `valueType`.
ConstRefType* ASTBuilder::getConstRefType(Type* valueType)
{
    PtrTypeBase* basePtrType = getPtrType(valueType, "ConstRefType");
    return dynamicCast<ConstRefType>(basePtrType);
}

// Construct the "OptionalType" specialization for `valueType`.
OptionalType* ASTBuilder::getOptionalType(Type* valueType)
{
    return as<OptionalType>(getSpecializedBuiltinType(valueType, "OptionalType"));
}

// Construct the pointer-like type whose magic declaration is `ptrTypeName`,
// specialized on `valueType`.
PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
{
    Type* specialized = getSpecializedBuiltinType(valueType, ptrTypeName);
    return as<PtrTypeBase>(specialized);
}

// Construct the "ArrayExpressionType" specialization for the given element
// type and element count.
//
// A null `elementCount` is replaced by an `int` value of
// `kUnsizedArrayMagicLength`. A constant count whose type is not `int` is
// re-created as an `int`-typed value.
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
    if (!elementCount)
    {
        elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength);
    }

    const bool countIsInt = (elementCount->getType() == getIntType());
    if (!countIsInt)
    {
        // Canonicalize a constant elementCount to int.
        auto constantCount = as<ConstantIntVal>(elementCount);
        if (constantCount)
        {
            elementCount = getIntVal(getIntType(), constantCount->getValue());
        }
    }

    Val* genericArgs[] = {elementType, elementCount};
    Type* specialized = getSpecializedBuiltinType(makeArrayView(genericArgs), "ArrayExpressionType");
    return as<ArrayExpressionType>(specialized);
}

// Construct the "ConstantBufferType" specialization for `elementType`.
ConstantBufferType* ASTBuilder::getConstantBufferType(Type* elementType)
{
    Type* specialized = getSpecializedBuiltinType(elementType, "ConstantBufferType");
    return as<ConstantBufferType>(specialized);
}

// Construct the "ParameterBlockType" specialization for `elementType`.
ParameterBlockType* ASTBuilder::getParameterBlockType(Type* elementType)
{
    Type* specialized = getSpecializedBuiltinType(elementType, "ParameterBlockType");
    return as<ParameterBlockType>(specialized);
}

// Construct the "HLSLStructuredBufferType" specialization for `elementType`.
HLSLStructuredBufferType* ASTBuilder::getStructuredBufferType(Type* elementType)
{
    Type* specialized = getSpecializedBuiltinType(elementType, "HLSLStructuredBufferType");
    return as<HLSLStructuredBufferType>(specialized);
}

// Construct the "HLSLRWStructuredBufferType" specialization for `elementType`.
HLSLRWStructuredBufferType* ASTBuilder::getRWStructuredBufferType(Type* elementType)
{
    Type* specialized = getSpecializedBuiltinType(elementType, "HLSLRWStructuredBufferType");
    return as<HLSLRWStructuredBufferType>(specialized);
}

// Construct the sampler state type from its magic declaration, with no
// generic argument.
//
// Like the other getters in this file, the lookup uses the magic name that
// matches the result class. Looking up the structured-buffer declaration
// here would yield a type the `as<SamplerStateType>` cast cannot match.
// NOTE(review): confirm "SamplerStateType" is the registered magic name.
SamplerStateType* ASTBuilder::getSamplerStateType()
{
    return as<SamplerStateType>(getSpecializedBuiltinType(nullptr, "SamplerStateType"));
}

// Construct the "VectorExpressionType" specialization for the given element
// type and element count.
VectorExpressionType* ASTBuilder::getVectorType(
    Type*    elementType,
    IntVal*  elementCount)
{
    // Canonicalize a constant elementCount to int.
    auto constantCount = as<ConstantIntVal>(elementCount);
    if (constantCount)
    {
        elementCount = getIntVal(getIntType(), constantCount->getValue());
    }

    Val* genericArgs[] = { elementType, elementCount };
    Type* specialized = getSpecializedBuiltinType(makeArrayView(genericArgs), "VectorExpressionType");
    return as<VectorExpressionType>(specialized);
}

// Construct the "MatrixExpressionType" specialization for the given element
// type, row count, column count and layout.
MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout)
{
    // Canonicalize constant size arguments to int, rows first, then columns.
    auto constantRows = as<ConstantIntVal>(rowCount);
    if (constantRows)
    {
        rowCount = getIntVal(getIntType(), constantRows->getValue());
    }

    auto constantCols = as<ConstantIntVal>(colCount);
    if (constantCols)
    {
        colCount = getIntVal(getIntType(), constantCols->getValue());
    }

    Val* genericArgs[] = { elementType, rowCount, colCount, layout };
    Type* specialized = getSpecializedBuiltinType(makeArrayView(genericArgs), "MatrixExpressionType");
    return as<MatrixExpressionType>(specialized);
}

// Construct the "DifferentialPairType" specialization from a value type and
// the witness passed alongside it.
DifferentialPairType* ASTBuilder::getDifferentialPairType(
    Type* valueType,
    Witness* primalIsDifferentialWitness)
{
    Val* genericArgs[] = { valueType, primalIsDifferentialWitness };
    Type* specialized = getSpecializedBuiltinType(makeArrayView(genericArgs), "DifferentialPairType");
    return as<DifferentialPairType>(specialized);
}

// Return a reference to the interface registered under the magic name
// "DifferentiableType", looked up with a null generic argument.
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
{
    return DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
}

bool ASTBuilder::isDifferentiableInterfaceAvailable()
    return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr);

// Construct the mesh output type selected by the kind of `modifier`:
// vertices -> "VerticesType", indices -> "IndicesType",
// primitives -> "PrimitivesType". Any other modifier is reported through
// SLANG_UNEXPECTED.
MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier(
    HLSLMeshShaderOutputModifier* modifier,
    Type* elementType,
    IntVal* maxElementCount)
{
    const char* declName = nullptr;
    if (as<HLSLVerticesModifier>(modifier))
    {
        declName = "VerticesType";
    }
    else if (as<HLSLIndicesModifier>(modifier))
    {
        declName = "IndicesType";
    }
    else if (as<HLSLPrimitivesModifier>(modifier))
    {
        declName = "PrimitivesType";
    }
    else
    {
        SLANG_UNEXPECTED("Unhandled mesh output modifier");
    }

    Val* genericArgs[] = { elementType, maxElementCount };
    Type* specialized = getSpecializedBuiltinType(makeArrayView(genericArgs), declName);
    return as<MeshOutputType>(specialized);
}

// Return the type that refers to the declaration produced by
// getDifferentiableInterfaceDecl().
Type* ASTBuilder::getDifferentiableInterfaceType()
{
    return DeclRefType::create(this, getDifferentiableInterfaceDecl());
}

// Return a reference to the magic declaration `builtinMagicTypeName`.
//
// When the declaration is generic it is applied to the single argument
// `genericArg`; otherwise a plain reference to the declaration is returned.
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
    auto magicDecl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);

    auto genericDecl = as<GenericDecl>(magicDecl);
    if (!genericDecl)
    {
        return makeDeclRef(magicDecl);
    }

    auto appliedDeclRef = getGenericAppDeclRef(makeDeclRef(genericDecl), makeConstArrayViewSingle(genericArg));
    return appliedDeclRef;
}

// Return a reference to the magic declaration `builtinMagicTypeName`.
//
// When the declaration is generic it is applied to `genericArgs`. Otherwise
// a plain reference to the declaration is returned, and `genericArgs` is
// expected to be empty.
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView<Val*> genericArgs)
{
    auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
    if (auto genericDecl = as<GenericDecl>(decl))
    {
        auto declRef = getGenericAppDeclRef(makeDeclRef(genericDecl), genericArgs);
        return declRef;
    }

    // A non-generic declaration has nothing to apply arguments to, so the
    // only thing to verify is that none were supplied. The check must not
    // require `decl` to be null, since a valid non-generic declaration
    // reaches this point with a non-null `decl`.
    SLANG_ASSERT(!genericArgs.getCount());
    return makeDeclRef(decl);
}

// Return the AndType built from `left` and `right`, obtained via getOrCreate.
Type* ASTBuilder::getAndType(Type* left, Type* right)
{
    return getOrCreate<AndType>(left, right);
}

// Return the ModifiedType for `base` carrying the `modifierCount` modifier
// values pointed to by `modifiers`, obtained via getOrCreate.
Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* modifiers)
{
    // makeArrayView is handed a non-const element pointer, so the constness
    // of the caller's array is cast away here.
    auto modifierView = makeArrayView(const_cast<Val**>(modifiers), modifierCount);
    return getOrCreate<ModifiedType>(base, modifierView);
}

// Return the UNormModifierVal obtained via getOrCreate.
Val* ASTBuilder::getUNormModifierVal()
{
    auto modifierVal = getOrCreate<UNormModifierVal>();
    return modifierVal;
}

// Return the SNormModifierVal obtained via getOrCreate.
Val* ASTBuilder::getSNormModifierVal()
{
    auto modifierVal = getOrCreate<SNormModifierVal>();
    return modifierVal;
}

// Return the NoDiffModifierVal obtained via getOrCreate.
Val* ASTBuilder::getNoDiffModifierVal()
{
    auto modifierVal = getOrCreate<NoDiffModifierVal>();
    return modifierVal;
}

// Return the FuncType for the given parameter types, result type and error
// type. A null `errorType` is replaced by BottomType.
FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType)
{
    if (errorType == nullptr)
    {
        errorType = getOrCreate<BottomType>();
    }

    return getOrCreate<FuncType>(parameters, result, errorType);
}

// Return the TupleType whose element types are the entries of `types`.
TupleType* ASTBuilder::getTupleType(List<Type*>& types)
{
    return getOrCreate<TupleType>(types.getArrayView());
}

// Return the TypeType wrapping `type`, obtained via getOrCreate.
TypeType* ASTBuilder::getTypeType(Type* type)
{
    return getOrCreate<TypeType>(type);
}

// Return the TypeEqualityWitness built with `type` as both of its operands.
TypeEqualityWitness* ASTBuilder::getTypeEqualityWitness(Type* type)
{
    return getOrCreate<TypeEqualityWitness>(type, type);
}

// Return the DeclaredSubtypeWitness built from `subType`, `superType` and
// the `declRefBase` of `declRef`.
DeclaredSubtypeWitness* ASTBuilder::getDeclaredSubtypeWitness(
    Type* subType,
    Type* superType,
    DeclRef<Decl> const& declRef)
{
    return getOrCreate<DeclaredSubtypeWitness>(subType, superType, declRef.declRefBase);
}

// Combine a witness that `a <: b` with a witness that `b <: c` into a
// witness that `a <: c`, simplifying the result in several special cases.
//
// NOTE(review): this definition is incomplete in this view - the enclosing
// braces, at least one `if` condition, the `top` label targeted by the
// `goto`, and the arguments of several calls are not visible. The notes
// below mark each such place; confirm against the full file.
SubtypeWitness* ASTBuilder::getTransitiveSubtypeWitness(
    SubtypeWitness* aIsSubtypeOfBWitness,
    SubtypeWitness* bIsSubtypeOfCWitness)
    // Our job is to take the witnesses that `a <: b` and `b <: c`
    // and produce a valid witness that `a <: c`.
    // There are some special cases we want to handle, in order
    // to simplify logic elsewhere in the compiler. For example,
    // if either of the input witnesses is a type *equality* witness,
    // then the other witness can be returned as-is.
    // If `a == b`, then the `b <: c` witness is also a witness of `a <: c`.
    // NOTE(review): the condition guarding this return is not visible here.
        return bIsSubtypeOfCWitness;

    // Similarly, if `b == c`, then the `a <: b` witness is a witness for `a <: c`.
    if (as<TypeEqualityWitness>(bIsSubtypeOfCWitness))
        return aIsSubtypeOfBWitness;

    // HACK: There is downstream code generation logic that assumes that
    // a `TransitiveSubtypeWitness` will never have a transitive witness
    // as its `b <: c` witness. If we are at risk of creating such a witness here,
    // we will shift things around to make that not be the case:
    if (auto bIsTransitiveSubtypeOfCWitness = as<TransitiveSubtypeWitness>(bIsSubtypeOfCWitness))
        // Let's call the intermediate type here `x`, we know that the `b <: c`
        // witness is based on witnesses that `b <: x` and `x <: c`:
        auto bIsSubtypeOfXWitness = bIsTransitiveSubtypeOfCWitness->getSubToMid();
        auto xIsSubtypeOfCWitness = bIsTransitiveSubtypeOfCWitness->getMidToSup();

        // We can recursively call this operation to produce a witness that
        // `a <: x`, based on the witnesses we already have for `a <: b` and `b <: x`:
        // NOTE(review): the argument list of this call is not visible here.
        auto aIsSubtypeOfXWitness = getTransitiveSubtypeWitness(

        // Now we can perform a "tail recursive" call to this function (via `goto`)
        // to combine the `a <: x` witness with our `x <: c` witness:
        aIsSubtypeOfBWitness = aIsSubtypeOfXWitness;
        bIsSubtypeOfCWitness = xIsSubtypeOfCWitness;
        goto top;

    // Sub type of the left witness and super type of the right witness.
    auto aType = aIsSubtypeOfBWitness->getSub();
    auto cType = bIsSubtypeOfCWitness->getSup();

    // If the right-hand side is a conjunction witness for `B <: C`
    // of the form `(B <: X)&(B <: Y)`, then we have it that `C = X&Y`
    // and we'd rather form a conjunction witness for `A <: C`
    // that is of the form `(A <: X)&(A <: Y)`.
    // NOTE(review): getConjunctionSubtypeWitness is declared in this file with
    // four parameters; the leading arguments of the call below are presumably
    // among the lines missing from this view.
    if(auto bIsSubtypeOfXAndY = as<ConjunctionSubtypeWitness>(bIsSubtypeOfCWitness))
        auto bIsSubtypeOfXWitness = bIsSubtypeOfXAndY->getLeftWitness();
        auto bIsSubtypeOfYWitness = bIsSubtypeOfXAndY->getRightWitness();

        return getConjunctionSubtypeWitness(
            getTransitiveSubtypeWitness(aIsSubtypeOfBWitness, bIsSubtypeOfXWitness),
            getTransitiveSubtypeWitness(aIsSubtypeOfBWitness, bIsSubtypeOfYWitness));

    // If the right-hand witness `R` is of the form `extract(i, W)`, then
    // `W` is a witness that `B <: X&Y&...` for some conjunction, where `C`
    // is one component of that conjunction.
    if(auto bIsSubtypeViaExtraction = as<ExtractFromConjunctionSubtypeWitness>(bIsSubtypeOfCWitness))
        // We decompose the witness `extract(i, W)` to get both
        // the witness `W` that `B <: X&Y&...` as well as the index
        // `i` of `C` within the conjunction.
        auto bIsSubtypeOfConjunction = bIsSubtypeViaExtraction->getConjunctionWitness();
        auto indexOfCInConjunction = bIsSubtypeViaExtraction->getIndexInConjunction();

        // We lift the extraction to the outside of the composition, by
        // forming a witness for `A <: C` that is of the form
        // `extract(i, L . W )`, where `L` is the left-hand witness (for `A <: B`).
        // The composition `L . W` is a witness that `A <: X&Y&...`, and
        // the `i`th component of it should be a witness that `A <: C`.
        // NOTE(review): the remaining arguments of this call are not visible here.
        return getExtractFromConjunctionSubtypeWitness(
            getTransitiveSubtypeWitness(aIsSubtypeOfBWitness, bIsSubtypeOfConjunction),

    // If none of the above special cases applied, then we are just going to create
    // a `TransitiveSubtypeWitness` directly.
    // TODO: Identify other cases that we can potentially simplify.
    // It is particularly notable that we do not have simplification rules that
    // detect when the left-hand side of a composition has some particular
    // structure. This may be fine, or it may not; we should write down a more
    // formal set of rules for the allowed structure of our witnesses to
    // guarantee that our simplifications are sufficient.

    // NOTE(review): the argument list of this getOrCreate call is not visible here.
    TransitiveSubtypeWitness* transitiveWitness = getOrCreate<TransitiveSubtypeWitness>(
    return transitiveWitness;

// Given a witness for `S <: L&R`, produce a witness for the component of the
// conjunction selected by `indexOfSuperTypeInConjunction`.
//
// NOTE(review): this definition is incomplete in this view - the enclosing
// braces and the arguments of the final getOrCreate call are not visible.
SubtypeWitness* ASTBuilder::getExtractFromConjunctionSubtypeWitness(
    Type*           subType,
    Type*           superType,
    SubtypeWitness* conjunctionWitness,
    int             indexOfSuperTypeInConjunction)
    // We are taking a witness `W` for `S <: L&R` and
    // using it to produce a witness for `S <: L`
    // or `S <: R`.

    // If it turns out that the witness `W` is itself
    // formed as a conjunction of witnesses: `(S <: L) & (S <: R)`,
    // then we can simply re-use the appropriate sub-witness.
    if (auto conjWitness = as<ConjunctionSubtypeWitness>(conjunctionWitness))
        return conjWitness->getComponentWitness(indexOfSuperTypeInConjunction);

    // TODO: Are there other simplification cases we should be paying attention
    // to here? For example:
    // * What if the original witness is transitive?

    // NOTE(review): the argument list of this getOrCreate call is not visible here.
    auto witness = getOrCreate<ExtractFromConjunctionSubtypeWitness>(
    return witness;

// Form a witness for `S <: L&R` from witnesses for `S <: L` and `S <: R`.
// When both inputs are extractions (index 0 and index 1) from one and the
// same conjunction witness, that witness is returned unchanged.
//
// NOTE(review): this definition is incomplete in this view - the enclosing
// braces and the arguments of the final getOrCreate call are not visible.
SubtypeWitness* ASTBuilder::getConjunctionSubtypeWitness(
    Type*           sub,
    Type*           lAndR,
    SubtypeWitness* subIsLWitness,
    SubtypeWitness* subIsRWitness)
    // If a conjunction witness for `S <: L&R` is being formed,
    // where the constituent witnesses for `S <: L` and `S <: R`
    // are themselves extractions of the first and second
    // components, respectively, of a single witness `W`, then
    // we can simply use `W` as-is.
    auto lExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsLWitness);
    auto rExtract = as<ExtractFromConjunctionSubtypeWitness>(subIsRWitness);
    if(lExtract && rExtract)
        // The left input must be component 0 and the right input component 1.
        if (lExtract->getIndexInConjunction() == 0
            && rExtract->getIndexInConjunction() == 1)
            auto lInner = lExtract->getConjunctionWitness();
            auto rInner = rExtract->getConjunctionWitness();
            // Both extractions must come from the very same witness.
            if (lInner == rInner)
                return lInner;

    // TODO: Depending on how we decide our canonicalized witnesses
    // should be structured, we could detect the case where the
    // `S <: L` and `S <: R` witnesses are both transitive compositions
    // of the form `X . A` and `X . B`, such that we *could* form
    // a composition around a conjunction - that is, produce
    // `X . (A & B)` rather than `(X . A) & (X . B)`.
    // For now we are favoring putting the composition (transitive
    // witness) deeper, so that we have more chances to expose a
    // conjunction witness at higher levels.

    // NOTE(review): the argument list of this getOrCreate call is not visible here.
    auto witness = getOrCreate<ConjunctionSubtypeWitness>(
    return witness;

// Free-function wrapper that forwards to `builder->getMemberDeclRef`.
DeclRef<Decl> _getMemberDeclRef(ASTBuilder* builder, DeclRef<Decl> parent, Decl* decl)
{
    return builder->getMemberDeclRef(parent, decl);
}

// Per-thread pointer to the current ASTBuilder; starts out null and is
// read/written through getCurrentASTBuilder() / setCurrentASTBuilder().
thread_local ASTBuilder* gCurrentASTBuilder = nullptr;

// Return the calling thread's current ASTBuilder (may be null).
ASTBuilder* getCurrentASTBuilder()
{
    ASTBuilder* const current = gCurrentASTBuilder;
    return current;
}

// Make `astBuilder` the calling thread's current ASTBuilder.
void setCurrentASTBuilder(ASTBuilder* astBuilder)
{
    gCurrentASTBuilder = astBuilder;
}

} // namespace Slang
back to top