// slang-ast-builder.h
//
// Declares:
//   - SharedASTBuilder: state shared between ASTBuilders (builtin / builtin-requirement /
//     magic decl registries, reflection-class lookup tables, a set of shared types, and a
//     private inner ASTBuilder used for those shared types).
//   - ValKey, the Hash specialization and ValKeyEqual: key, hasher and equality functor used
//     to deduplicate Val nodes in ASTBuilder::m_cachedNodes.
//   - ASTBuilder: arena-backed creation of AST nodes (create), deduplicated creation of Vals
//     (getOrCreate), and helpers for building DeclRefs, types and subtype witnesses.
//   - SetASTBuilderContextRAII / SLANG_AST_BUILDER_RAII: scoped override of the current
//     ASTBuilder.
//
// NOTE(review): template argument lists appear to have been stripped from this copy of the
// header (e.g. the bare `#include` below, `Dictionary m_magicDecls;`,
// `template T* createImpl()`, and `as(...)` / `create()` calls with no target type). As
// written these lines are not valid C++; restore the arguments from the upstream source
// rather than guessing them.
#ifndef SLANG_AST_BUILDER_H
#define SLANG_AST_BUILDER_H

#include
#include "slang-ast-support-types.h"
#include "slang-ast-all.h"

#include "../core/slang-type-traits.h"
#include "../core/slang-memory-arena.h"

namespace Slang
{

/// State shared between all ASTBuilders. Holds registries of builtin, builtin-requirement
/// and magic declarations, name/slice -> reflection-class lookup tables, lazily created
/// shared types, and a private inner ASTBuilder used to create those shared types.
/// `init` must be called before the object is used.
class SharedASTBuilder : public RefObject
{
    friend class ASTBuilder;
public:
    /// Registration hooks (defined out of line) for declarations that carry the builtin-type,
    /// builtin-requirement, or magic-type modifier respectively.
    void registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier);
    void registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier);
    void registerMagicDecl(Decl* decl, MagicTypeModifier* modifier);

    /// Get the string type
    Type* getStringType();
    /// Get the native string type
    Type* getNativeStringType();
    /// Get the enum type type
    Type* getEnumTypeType();
    /// Get the __Dynamic type
    Type* getDynamicType();
    /// Get the NullPtr type
    Type* getNullPtrType();
    /// Get the None type
    Type* getNoneType();
    /// Get the `IDifferentiable` type
    Type* getDiffInterfaceType();

    Type* getErrorType();
    Type* getBottomType();
    Type* getInitializerListType();
    Type* getOverloadedType();

    /// Look up reflection class info / syntax class by name or by string slice.
    /// Presumably backed by m_nameToTypeMap / m_sliceToTypeMap below — confirm in the .cpp.
    const ReflectClassInfo* findClassInfo(Name* name);
    SyntaxClass findSyntaxClass(Name* name);
    const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice);
    SyntaxClass findSyntaxClass(const UnownedStringSlice& slice);

    // Look up a magic declaration by its name
    Decl* findMagicDecl(String const& name);
    // NOTE(review): presumably the non-failing variant of findMagicDecl (returns null when
    // the name is not registered) — confirm against the implementation.
    Decl* tryFindMagicDecl(String const& name);

    /// Return the decl registered for `kind` in m_builtinRequirementDecls.
    /// NOTE(review): behaviour for an unregistered kind is whatever Dictionary::getValue does
    /// for a missing key — confirm before calling with an arbitrary kind.
    Decl* findBuiltinRequirementDecl(BuiltinRequirementKind kind)
    {
        return m_builtinRequirementDecls.getValue(kind);
    }

    /// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session.
    NamePool* getNamePool() { return m_namePool; }

    /// Must be called before used
    void init(Session* session);

    SharedASTBuilder();
    ~SharedASTBuilder();

    /// The private builder used to create the shared types (m_astBuilder).
    ASTBuilder* getInnerASTBuilder() { return m_astBuilder; }

protected:
    // State shared between ASTBuilders
    Type* m_errorType = nullptr;
    Type* m_bottomType = nullptr;
    Type* m_initializerListType = nullptr;
    Type* m_overloadedType = nullptr;

    // The following types are created lazily, such that part of their definition
    // can be in the standard library
    //
    // Note(tfoley): These logically belong to `Type`,
    // but order-of-declaration stuff makes that tricky
    //
    // TODO(tfoley): These should really belong to the compilation context!
    //
    Type* m_stringType = nullptr;
    Type* m_nativeStringType = nullptr;
    Type* m_enumTypeType = nullptr;
    Type* m_dynamicType = nullptr;
    Type* m_nullPtrType = nullptr;
    Type* m_noneType = nullptr;
    Type* m_diffInterfaceType = nullptr;

    // One slot per BaseType; read directly by ASTBuilder::getBoolType() etc.
    // NOTE(review): unlike the pointers above this array has no in-class initializer;
    // presumably it is cleared in the constructor — confirm, since ASTBuilder reads it
    // without a null check.
    Type* m_builtinTypes[Index(BaseType::CountOf)];

    // Registries filled by registerMagicDecl / registerBuiltinRequirementDecl (presumably).
    Dictionary m_magicDecls;
    Dictionary m_builtinRequirementDecls;

    // Lookup tables for findClassInfo / findSyntaxClass (presumably).
    Dictionary m_sliceToTypeMap;
    Dictionary m_nameToTypeMap;

    NamePool* m_namePool = nullptr;

    // This is a private builder used for these shared types
    ASTBuilder* m_astBuilder = nullptr;
    Session* m_session = nullptr;

    // NOTE(review): presumably the source of ASTBuilder ids — confirm in the .cpp.
    Index m_id = 1;
};

/// Hash-consing key for a Val node. The hash is computed once, from the node's
/// `astNodeType` and the raw `intOperand` bits of every operand. Two keys compare equal when
/// they point at the same node, or when type, operand count and every operand's
/// `intOperand` bits match.
struct ValKey
{
    Val* val;
    HashCode hashCode;
    // NOTE(review): the defaulted constructor leaves `val` and `hashCode` uninitialized, and
    // operator== dereferences `val`; never compare a default-constructed key.
    ValKey() = default;
    ValKey(Val* v)
    {
        val = v;
        Hasher hasher;
        hasher.hashValue(v->astNodeType);
        for (auto& operand : v->m_operands)
            hasher.hashValue(operand.values.intOperand);
        hashCode = hasher.getResult();
    }

    bool operator==(ValKey other) const
    {
        // Fast path: identical node.
        if (val == other.val)
            return true;
        // Cheap rejections before the operand-by-operand comparison.
        if (hashCode != other.hashCode)
            return false;
        if (val->astNodeType != other.val->astNodeType)
            return false;
        if (val->m_operands.getCount() != other.val->m_operands.getCount())
            return false;
        for (Index i = 0; i < val->m_operands.getCount(); i++)
            if (val->m_operands[i].values.intOperand != other.val->m_operands[i].values.intOperand)
                return false;
        return true;
    }

    /// Compare against a description of a node that has not been created yet. This lets
    /// ASTBuilder::_getOrCreateImpl probe m_cachedNodes before allocating anything.
    bool operator==(const ValNodeDesc& desc) const
    {
        if (hashCode != desc.getHashCode())
            return false;
        if (val->astNodeType != desc.type)
            return false;
        if (val->m_operands.getCount() != desc.operands.getCount())
            return false;
        for (Index i = 0; i < val->m_operands.getCount(); i++)
            if (val->m_operands[i].values.intOperand != desc.operands[i].values.intOperand)
                return false;
        return true;
    }

    HashCode getHashCode() const { return hashCode; }
};

// Add a specialization which can hash both ValKey and ValNodeDesc.
// `is_transparent` marks the hasher as usable for heterogeneous lookup by ValNodeDesc.
// NOTE(review): such lookups only find existing entries if the ValNodeDesc hash equals
// ValKey::hashCode for the same node; ValKey::operator==(ValNodeDesc) compares against
// desc.getHashCode(), so presumably desc.init() hashes the same fields — confirm.
template<> struct Hash
{
    using is_transparent = void;
    auto operator()(const ValKey& k) const { return k.getHashCode(); }
    auto operator()(const ValNodeDesc& k) const { return Hash{}(k); }
};

// A functor which can compare ValKey for equality with ValNodeDesc.
// Only the (ValKey, ValKey) and (ValNodeDesc, ValKey) argument orders are provided, so the
// container is presumably expected to pass the probe as the first argument.
struct ValKeyEqual
{
    using is_transparent = void;
    bool operator()(const Slang::ValKey& a, const Slang::ValKey& b) const { return a == b; }
    bool operator()(const Slang::ValNodeDesc& a, const Slang::ValKey& b) const { return b == a; }
};

/// Creates AST nodes in an arena owned by this builder (m_arena). Plain nodes are created
/// with create(); Val nodes must go through getOrCreate(), which deduplicates them via
/// m_cachedNodes.
class ASTBuilder : public RefObject
{
    friend class SharedASTBuilder;
public:
    /// Return the cached Val matching `desc`, or create a node of `desc.type`, copy the
    /// operands into it, and add it to m_cachedNodes.
    Val* _getOrCreateImpl(ValNodeDesc&& desc)
    {
        if (auto found = m_cachedNodes.tryGetValue(desc))
            return *found;
        auto node = as(createByNodeType(desc.type));
        SLANG_ASSERT(node);
        for (auto& operand : desc.operands)
            node->m_operands.add(operand);
        auto result = node;
        m_cachedNodes.add(ValKey(node), _Move(node));
        return result;
    }

    /// A cache for AST nodes that are entirely defined by their node type, with
    /// no need for additional state.
    Dictionary, ValKeyEqual> m_cachedNodes;

    // Not referenced within this header; see its users elsewhere.
    Dictionary> m_cachedGenericDefaultArgs;

    /// Create AST types.
    /// Allocates sizeof(T) bytes from m_arena, zero-fills them, placement-news a T, and
    /// registers the node via _initAndAdd.
    template T* createImpl()
    {
        auto alloced = m_arena.allocate(sizeof(T));
        memset(alloced, 0, sizeof(T));
        auto result = _initAndAdd(new (alloced) T);
        return result;
    }

    /// As above, forwarding `args` to T's constructor.
    template T* createImpl(TArgs&&... args)
    {
        auto alloced = m_arena.allocate(sizeof(T));
        memset(alloced, 0, sizeof(T));
        auto result = _initAndAdd(new (alloced) T(std::forward(args)...));
        return result;
    }

    /// Create a non-Val node. Val subclasses are rejected at compile time because they must
    /// go through getOrCreate so that they are deduplicated.
    template T* create()
    {
        static_assert(!IsBaseOf::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead.");
        return createImpl();
    }

    /// As above, passing `args` on to T's constructor.
    /// NOTE(review): `args...` is passed without std::forward, so rvalue arguments reach
    /// createImpl as lvalues and are copied rather than moved.
    template T* create(TArgs&&... args)
    {
        static_assert(!IsBaseOf::Value, "ASTBuilder::create cannot be used to create a Val, use getOrCreate instead.");
        return createImpl(args...);
    }

public:
    // For compile time check to see if thing being constructed is an AST type
    template struct IsValidType
    {
        enum { Value = IsBaseOf::Value };
    };

    /// Epoch counter (defined out of line). _initAndAdd stamps every new Val with the
    /// current epoch in m_resolvedValEpoch.
    Index getEpoch();
    void incrementEpoch();

    // Same arena as getMemoryArena().
    MemoryArena& getArena() { return m_arena; }

    /// Build a ValNodeDesc of kind T::kType whose operands come from `args` (taken by value),
    /// then return the existing equal node or a newly created one.
    template SLANG_FORCE_INLINE T* getOrCreate(TArgs ... args)
    {
        SLANG_COMPILE_TIME_ASSERT(IsValidType::Value);
        ValNodeDesc desc;
        desc.type = T::kType;
        addOrAppendToNodeList(desc.operands, args...);
        desc.init();
        auto result = (T*)_getOrCreateImpl(_Move(desc));
        return result;
    }

    /// Operand-less variant of getOrCreate.
    template SLANG_FORCE_INLINE T* getOrCreate()
    {
        SLANG_COMPILE_TIME_ASSERT(IsValidType::Value);
        ValNodeDesc desc;
        desc.type = T::kType;
        desc.init();
        auto result = (T*)_getOrCreateImpl(_Move(desc));
        return result;
    }

    /// Create an interface decl with a member decl named "This" (located at `loc`), which in
    /// turn has a constraint member (also located at `loc`).
    InterfaceDecl* createInterfaceDecl(SourceLoc loc)
    {
        auto interfaceDecl = create();
        // Always include a `This` member and a `This:IThisInterface` member.
        auto thisDecl = create();
        thisDecl->nameAndLoc.name = m_sharedASTBuilder->getNamePool()->getName(UnownedStringSlice("This", 4));
        thisDecl->nameAndLoc.loc = loc;
        interfaceDecl->addMember(thisDecl);
        auto thisConstraint = create();
        thisConstraint->loc = loc;
        thisDecl->addMember(thisConstraint);
        return interfaceDecl;
    }

    /// Wrap `decl` in a DeclRef with no parent context. The enable_if (its argument list is
    /// missing in this copy) presumably restricts T to Decl subclasses.
    template DeclRef getDirectDeclRef(T* decl, typename std::enable_if_t>* = nullptr)
    {
        return DeclRef(decl);
    }

    /// Build a reference to `memberDecl` as seen through `parent`. Cases, in order:
    ///   - null parent: direct reference;
    ///   - generic value/type params: always direct;
    ///   - a special case (the tested types are missing in this copy) that returns `parent`
    ///     itself, cast to T;
    ///   - parent is itself a member reference (presumably): recurse on the parent's parent;
    ///   - parent is a lookup reference whose decl is a ThisTypeDecl, ExtensionDecl or
    ///     AssocTypeDecl: look `memberDecl` up directly through the same witness;
    ///   - parent is a direct reference (presumably): makeDeclRef(memberDecl);
    ///   - otherwise (including lookup refs of other kinds): in debug builds assert that
    ///     `memberDecl` really is nested in the parent, then getOrCreate a member ref node.
    template DeclRef getMemberDeclRef(DeclRef parent, T* memberDecl)
    {
        if (!parent)
            return getDirectDeclRef(memberDecl);
        // A Generic value/type ParamDecl is always referred to directly.
        if (as(memberDecl) || as(memberDecl))
            return getDirectDeclRef(memberDecl);
        if (as(memberDecl) && !as(memberDecl->parentDecl))
            return as(parent);
        if (auto parentMemberDeclRef = as(parent.declRefBase))
        {
            return DeclRef(getMemberDeclRef(parentMemberDeclRef->getParent(), memberDecl));
        }
        else if (auto lookupDeclRef = as(parent.declRefBase))
        {
            // Handle some special-case rules due to the way some of our builtin decls are
            // represented.
            // - Member(Lookup(w, This), x) ==> Lookup(w, X)
            //   Lookup of x from This is a lookup from w directly.
            // - Member(Lookup(w, someExtension), x) ==> Lookup(w, X)
            //   Lookup of a decl defined in an extension is to lookup directly.
            // - Member(Lookup(w, AssociatedType), TypeConstraintDecl) ==> Lookup(w, TypeConstraintDecl)
            //   Type constraint of an associated type is defined directly in w.
            auto parentDeclKind = lookupDeclRef->getDecl()->astNodeType;
            switch (parentDeclKind)
            {
            case ASTNodeType::ThisTypeDecl:
            case ASTNodeType::ExtensionDecl:
            case ASTNodeType::AssocTypeDecl:
                return getLookupDeclRef(lookupDeclRef->getLookupSource(), lookupDeclRef->getWitness(), memberDecl);
            default:
                // Other lookup parents fall through to the generic member-ref path below.
                break;
            }
        }
        else if (auto directDeclRef = as(parent.declRefBase))
        {
            // NOTE(review): `directDeclRef` is unused; only the type test matters here.
            return makeDeclRef(memberDecl);
        }
#if _DEBUG
        // Verify that member is indeed a member of parent.
        // The loop first walks `parentDecl` up while it matches some decl kind (the type is
        // missing in this copy), then checks `memberDecl`'s ancestor chain for it.
        auto parentDecl = parent.getDecl();
        while (as(parentDecl))
            parentDecl = parentDecl->parentDecl;
        bool foundParent = false;
        for (Decl* dd = memberDecl; dd; dd = dd->parentDecl)
        {
            if (dd == parentDecl)
            {
                foundParent = true;
                break;
            }
        }
        SLANG_ASSERT(foundParent);
#endif
        return DeclRef(getOrCreate(memberDecl, parent.declRefBase));
    }

    /// Deduplicated constant integer value of the given type.
    ConstantIntVal* getIntVal(Type* type, IntegerLiteralValue value)
    {
        return getOrCreate(type, value);
    }

    /// Deduplicated cast of `base` to `type`. Existing cast wrappers around `base` are
    /// peeled off first, so the result never wraps another cast of this kind.
    TypeCastIntVal* getTypeCastIntVal(Type* type, Val* base)
    {
        // Unwrap any existing type casts.
        while (auto baseTypeCast = as(base))
            base = baseTypeCast->getBase();
        return getOrCreate(type, base);
    }

    /// Reference to a generic applied to `args`. `innerDecl` defaults to the generic's
    /// `inner` decl.
    DeclRef getGenericAppDeclRef(DeclRef genericDeclRef, ConstArrayView args, Decl* innerDecl = nullptr)
    {
        if (!innerDecl)
            innerDecl = genericDeclRef.getDecl()->inner;
        return getOrCreate(innerDecl, genericDeclRef, args);
    }

    /// Same as above, taking the arguments as a Val operand view.
    DeclRef getGenericAppDeclRef(DeclRef genericDeclRef, Val::OperandView args, Decl* innerDecl = nullptr)
    {
        if (!innerDecl)
            innerDecl = genericDeclRef.getDecl()->inner;
        return getOrCreate(innerDecl, genericDeclRef, args);
    }

    /// Deduplicated reference to `declToLookup` looked up on `base` through `subtypeWitness`.
    LookupDeclRef* getLookupDeclRef(Type* base, SubtypeWitness* subtypeWitness, Decl* declToLookup)
    {
        auto result = getOrCreate(declToLookup, base, subtypeWitness);
        return result;
    }

    /// As above, using the witness's subtype (`getSub()`) as the lookup base.
    LookupDeclRef* getLookupDeclRef(SubtypeWitness* subtypeWitness, Decl* declToLookup)
    {
        return getLookupDeclRef(subtypeWitness->getSub(), subtypeWitness, declToLookup);
    }

    NodeBase* createByNodeType(ASTNodeType nodeType);

    /// Get the built in types.
    /// These read the shared m_builtinTypes table directly, without a null check.
    SLANG_FORCE_INLINE Type* getBoolType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Bool)]; }
    SLANG_FORCE_INLINE Type* getHalfType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Half)]; }
    SLANG_FORCE_INLINE Type* getFloatType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Float)]; }
    SLANG_FORCE_INLINE Type* getDoubleType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Double)]; }
    SLANG_FORCE_INLINE Type* getIntType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int)]; }
    SLANG_FORCE_INLINE Type* getInt64Type() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Int64)]; }
    SLANG_FORCE_INLINE Type* getIntPtrType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::IntPtr)]; }
    SLANG_FORCE_INLINE Type* getUIntType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt)]; }
    SLANG_FORCE_INLINE Type* getUInt64Type() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UInt64)]; }
    SLANG_FORCE_INLINE Type* getUIntPtrType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::UIntPtr)]; }
    SLANG_FORCE_INLINE Type* getVoidType() { return m_sharedASTBuilder->m_builtinTypes[Index(BaseType::Void)]; }

    /// Get a builtin type by the BaseType
    SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; }

    Type* getSpecializedBuiltinType(Type* typeParam, const char* magicTypeName);
    Type* getSpecializedBuiltinType(ArrayView genericArgs, const char* magicTypeName);

    // Forwarders to the SharedASTBuilder's shared types.
    Type* getInitializerListType() { return m_sharedASTBuilder->getInitializerListType(); }
    Type* getOverloadedType() { return m_sharedASTBuilder->getOverloadedType(); }
    Type* getErrorType() { return m_sharedASTBuilder->getErrorType(); }
    Type* getBottomType() { return m_sharedASTBuilder->getBottomType(); }
    Type* getStringType() { return m_sharedASTBuilder->getStringType(); }
    Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); }
    Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); }
    Type* getEnumTypeType() { return m_sharedASTBuilder->getEnumTypeType(); }
    Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); }

    // Construct the type `Ptr<valueType>`, where `Ptr`
    // is looked up as a builtin type.
    PtrType* getPtrType(Type* valueType);
    // Construct the type `Out<valueType>`
    OutType* getOutType(Type* valueType);
    // Construct the type `InOut<valueType>`
    InOutType* getInOutType(Type* valueType);
    // Construct the type `Ref<valueType>`
    RefType* getRefType(Type* valueType);
    // Construct the type `ConstRef<valueType>`
    ConstRefType* getConstRefType(Type* valueType);
    // Construct the type `Optional<valueType>`
    OptionalType* getOptionalType(Type* valueType);
    // Construct a pointer type like `Ptr<valueType>`, but where
    // the actual type name for the pointer type is given by `ptrTypeName`
    PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName);

    ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount);
    VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount);
    MatrixExpressionType* getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout);
    ConstantBufferType* getConstantBufferType(Type* elementType);
    ParameterBlockType* getParameterBlockType(Type* elementType);
    HLSLStructuredBufferType* getStructuredBufferType(Type* elementType);
    HLSLRWStructuredBufferType* getRWStructuredBufferType(Type* elementType);
    SamplerStateType* getSamplerStateType();
    DifferentialPairType* getDifferentialPairType(
        Type* valueType,
        Witness* primalIsDifferentialWitness);
    DeclRef getDifferentiableInterfaceDecl();
    Type* getDifferentiableInterfaceType();
    bool isDifferentiableInterfaceAvailable();
    MeshOutputType* getMeshOutputTypeFromModifier(
        HLSLMeshShaderOutputModifier* modifier,
        Type* elementType,
        IntVal* maxElementCount);
    DeclRef getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg);
    DeclRef getBuiltinDeclRef(const char* builtinMagicTypeName, ArrayView genericArgs);
    Type* getAndType(Type* left, Type* right);
    Type* getModifiedType(Type* base, Count modifierCount, Val* const* modifiers);
    /// Convenience overload: passes the list's count and buffer to the overload above.
    Type* getModifiedType(Type* base, List const& modifiers)
    {
        return getModifiedType(base, modifiers.getCount(), modifiers.getBuffer());
    }
    Val* getUNormModifierVal();
    Val* getSNormModifierVal();
    Val* getNoDiffModifierVal();
    TupleType* getTupleType(List& types);
    FuncType* getFuncType(ArrayView parameters, Type* result, Type* errorType = nullptr);
    TypeType* getTypeType(Type* type);

    /// Produce a witness that `T : T` for any type `T`
    TypeEqualityWitness* getTypeEqualityWitness(
        Type* type);

    DeclaredSubtypeWitness* getDeclaredSubtypeWitness(
        Type* subType,
        Type* superType,
        DeclRef const& declRef);

    /// Produce a witness that `A <: C` given witnesses that `A <: B` and `B <: C`
    SubtypeWitness* getTransitiveSubtypeWitness(
        SubtypeWitness* aIsSubtypeOfBWitness,
        SubtypeWitness* bIsSubtypeOfCWitness);

    /// Produce a witness that `T <: L` or `T <: R` given `T <: L&R`
    SubtypeWitness* getExtractFromConjunctionSubtypeWitness(
        Type* subType,
        Type* superType,
        SubtypeWitness* subIsSubtypeOfConjunction,
        int indexOfSuperTypeInConjunction);

    /// Produce a witness that `S <: L&R` given witnesses that `S <: L` and `S <: R`
    SubtypeWitness* getConjunctionSubtypeWitness(
        Type* sub,
        Type* lAndR,
        SubtypeWitness* subIsLWitness,
        SubtypeWitness* subIsRWitness);

    /// Helpers to get type info from the SharedASTBuilder
    const ReflectClassInfo* findClassInfo(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findClassInfo(slice); }
    SyntaxClass findSyntaxClass(const UnownedStringSlice& slice) { return m_sharedASTBuilder->findSyntaxClass(slice); }
    const ReflectClassInfo* findClassInfo(Name* name) { return m_sharedASTBuilder->findClassInfo(name); }
    SyntaxClass findSyntaxClass(Name* name) { return m_sharedASTBuilder->findSyntaxClass(name); }

    // Same arena as getArena().
    MemoryArena& getMemoryArena() { return m_arena; }

    /// Get the shared AST builder
    SharedASTBuilder* getSharedASTBuilder() { return m_sharedASTBuilder; }

    /// Get the global session
    Session* getGlobalSession() { return m_sharedASTBuilder->m_session; }

    Index getId() { return m_id; }

    /// Ctor
    ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name);

    /// Dtor
    ~ASTBuilder();

protected:
    // Special default Ctor that can only be used by SharedASTBuilder
    ASTBuilder();

    /// Finish registering a freshly placement-new'ed node with this builder:
    /// initialize it with T::kType, remember it in m_dtorNodes if T's destructor is
    /// non-trivial, and stamp Val nodes with the current epoch.
    template SLANG_FORCE_INLINE T* _initAndAdd(T* node)
    {
        SLANG_COMPILE_TIME_ASSERT(IsValidType::Value);
        node->init(T::kType, this);
        // Only add it if it has a dtor that does some work
        if (!std::is_trivially_destructible::value)
        {
            // Keep such that dtor can be run on ASTBuilder being dtored
            m_dtorNodes.add(node);
        }
        // Runtime class check (not a compile-time one) against the reflected Val class info.
        if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Val::kType)))
        {
            auto val = (Val*)(node);
            val->m_resolvedValEpoch = getEpoch();
        }
        return node;
    }

    String m_name;
    // NOTE(review): no in-class initializer; set by the constructors defined out of line —
    // confirm the protected default constructor also sets it.
    Index m_id;

    /// List of all nodes that require being dtored when ASTBuilder is dtored
    List m_dtorNodes;

    // NOTE(review): no in-class initializer; many inline accessors above dereference it.
    SharedASTBuilder* m_sharedASTBuilder;

    // Backing storage for every node created by createImpl.
    MemoryArena m_arena;
};

// Retrieves the ASTBuilder for the current compilation session.
ASTBuilder* getCurrentASTBuilder();

// Sets the ASTBuilder for the current compilation session.
void setCurrentASTBuilder(ASTBuilder* astBuilder);

/// Scoped override of the current ASTBuilder: the constructor saves the current builder and
/// installs `astBuilder`; the destructor reinstalls the saved one.
/// NOTE(review): copy construction/assignment are not deleted; a copy would restore the
/// previous builder a second time when it is destroyed.
struct SetASTBuilderContextRAII
{
    ASTBuilder* previousASTBuilder = nullptr;
    SetASTBuilderContextRAII(ASTBuilder* astBuilder)
    {
        previousASTBuilder = getCurrentASTBuilder();
        setCurrentASTBuilder(astBuilder);
    }
    ~SetASTBuilderContextRAII()
    {
        setCurrentASTBuilder(previousASTBuilder);
    }
};

// Declares a local SetASTBuilderContextRAII with a fixed name, so it can be used at most
// once per scope.
#define SLANG_AST_BUILDER_RAII(astBuilder) SetASTBuilderContextRAII _setASTBuilderContextRAII(astBuilder)

} // namespace Slang

#endif