https://github.com/shader-slang/slang
Tip revision: 5902acdabc4445a65741a7a6a3a95f223e301059 authored by Yong He on 23 January 2024, 07:19:40 UTC
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
[LSP] Fetch configs directly from didConfigurationChanged message. (#3478)
Tip revision: 5902acd
slang-ast-val.h
// slang-ast-val.h
#pragma once
#include "slang-ast-base.h"
namespace Slang {
// Syntax class definitions for compile-time values.
/// A decl-ref whose only operand is the referenced `Decl` itself.
///
/// Operand layout: [0] = decl.
class DirectDeclRef : public DeclRefBase
{
public:
SLANG_AST_CLASS(DirectDeclRef)
DirectDeclRef(Decl* decl)
{
setOperands(decl);
}
// Override hooks (per the `_...Override` naming used throughout this file);
// their definitions are not in this header.
DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
void _toTextOverride(StringBuilder& out);
Val* _resolveImplOverride();
DeclRefBase* _getBaseOverride();
};
// Represents a static member of a base decl.
// Note that we automatically fold the DeclRef if the path is known to be static.
// For example, MemberDeclRef(DirectDeclRef(A), B) ==> DirectDeclRef(B),
// and MemberDeclRef(MemberDeclRef(A, B), C) ==> MemberDeclRef(A, C).
//
/// Operand layout: [0] = member decl, [1] = parent decl-ref.
class MemberDeclRef : public DeclRefBase
{
public:
SLANG_AST_CLASS(MemberDeclRef);
// The parent decl-ref stored in operand 1 (via `as<DeclRefBase>`).
DeclRefBase* getParentOperand() { return as<DeclRefBase>(getOperand(1)); }
MemberDeclRef(Decl* decl, DeclRefBase* parent)
{
setOperands(decl, parent);
}
// Override hooks; definitions are not in this header.
DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
void _toTextOverride(StringBuilder& out);
Val* _resolveImplOverride();
DeclRefBase* _getBaseOverride();
};
// Represents a lookup of SuperType::`m_decl` from a `lookupSourceType` type that is known to conform to SuperType.
/// Operand layout: [0] = decl to look up, [1] = lookup source type,
/// [2] = subtype witness for the lookup source.
class LookupDeclRef : public DeclRefBase
{
public:
SLANG_AST_CLASS(LookupDeclRef);
// m_decl (operand 0) represents the decl in SuperType that we want to lookup.
// The source type that we are looking up from (operand 1).
Type* getLookupSource()
{
return as<Type>(getOperand(1));
}
// Witness that `lookupSourceType`:SuperType (operand 2).
SubtypeWitness* getWitness()
{
return as<SubtypeWitness>(getOperand(2));
}
LookupDeclRef(Decl* declToLookup, Type* lookupSource, SubtypeWitness* witness)
{
setOperands(declToLookup, lookupSource, witness);
}
// Defined out of line; presumably returns the decl of SuperType (TODO confirm).
Decl* getSupDecl();
// Override hooks; definitions are not in this header.
DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
void _toTextOverride(StringBuilder& out);
Val* _resolveImplOverride();
DeclRefBase* _getBaseOverride();
private:
// Out-of-line helper taking a replacement witness and lookup source.
Val* tryResolve(SubtypeWitness* newWitness, Type* newLookupSource);
};
// Represents a specialization of a generic decl.
/// Operand layout: [0] = inner decl, [1] = decl-ref to the generic,
/// [2..] = generic arguments (one operand per argument).
class GenericAppDeclRef : public DeclRefBase
{
public:
SLANG_AST_CLASS(GenericAppDeclRef);
DeclRefBase* getGenericDeclRef() { return as<DeclRefBase>(getOperand(1)); }
// Arguments occupy every operand after the first two.
Index getArgCount() { return getOperandCount() - 2; }
Val* getArg(Index index) { return getOperand(index + 2); }
// The two constructors build the same operand layout; they differ only in
// the view type used to pass `args`.
GenericAppDeclRef(Decl* innerDecl, DeclRefBase* genericDeclRef, OperandView<Val> args)
{
m_operands.add(ValNodeOperand(innerDecl));
m_operands.add(ValNodeOperand(genericDeclRef));
for (auto arg : args)
{
m_operands.add(ValNodeOperand(arg));
}
}
GenericAppDeclRef(Decl* innerDecl, DeclRefBase* genericDeclRef, ConstArrayView<Val*> args)
{
m_operands.add(ValNodeOperand(innerDecl));
m_operands.add(ValNodeOperand(genericDeclRef));
for (auto arg : args)
{
m_operands.add(ValNodeOperand(arg));
}
}
// Defined out of line.
GenericDecl* getGenericDecl();
OperandView<Val> getArgs() { return OperandView<Val>(this, 2, getArgCount()); }
// Override hooks; definitions are not in this header.
DeclRefBase* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
void _toTextOverride(StringBuilder& out);
Val* _resolveImplOverride();
DeclRefBase* _getBaseOverride();
};
// A compile-time integer (may not have a specific concrete value)
/// Abstract base for integer values. Operand 0 holds the value's `Type`; every
/// concrete subclass in this file passes the type as its first operand.
class IntVal : public Val
{
SLANG_ABSTRACT_AST_CLASS(IntVal)
Type* getType() { return as<Type>(getOperand(0)); }
// Resolves to itself unless a subclass supplies its own override.
Val* _resolveImplOverride() { return this; }
};
// Trivial case of a value that is just a constant integer
/// Operand layout: [0] = type, [1] = integer literal value.
class ConstantIntVal : public IntVal
{
SLANG_AST_CLASS(ConstantIntVal)
IntegerLiteralValue getValue() { return getIntConstOperand(1); }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
ConstantIntVal(Type* inType, IntegerLiteralValue inValue)
{
setOperands(inType, inValue);
}
};
// The logical "value" of a reference to a generic value parameter
/// Operand layout: [0] = type, [1] = decl-ref to the generic value parameter.
class GenericParamIntVal : public IntVal
{
SLANG_AST_CLASS(GenericParamIntVal)
DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(1)); }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
{
setOperands(inType, inDeclRef);
}
};
/// An int val built from a target type and a base value; presumably it
/// represents `base` cast to that type (name-based — TODO confirm).
///
/// Operand layout: [0] = result type, [1] = base value.
class TypeCastIntVal : public IntVal
{
SLANG_AST_CLASS(TypeCastIntVal)
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
Val* getBase() { return getOperand(1); }
TypeCastIntVal(Type* inType, Val* inBase)
{
setOperands(inType, inBase);
}
// Static helper defined out of line; presumably attempts to constant-fold
// the cast, reporting through `sink` (TODO confirm against the .cpp).
static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink);
};
// A compile-time int val that is the result of some general computation.
/// Operand layout: [0] = result type, [1] = decl-ref to the function,
/// [2] = function type, [3..] = IntVal arguments.
class FuncCallIntVal : public IntVal
{
SLANG_AST_CLASS(FuncCallIntVal)
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
DeclRef<Decl> getFuncDeclRef() { return as<DeclRefBase>(getOperand(1)); }
Type* getFuncType() { return as<Type>(getOperand(2)); }
// Arguments occupy every operand after the first three.
OperandView<IntVal> getArgs() { return OperandView<IntVal>(this, 3, getOperandCount() - 3); }
Index getArgCount() { return getOperandCount() - 3; }
FuncCallIntVal(Type* inType, DeclRef<Decl> inFuncDeclRef, Type* inFuncType, ArrayView<IntVal*> inArgs)
{
setOperands(inType, inFuncDeclRef, inFuncType);
for (auto arg : inArgs)
m_operands.add(ValNodeOperand(arg));
}
// Static helper defined out of line; presumably attempts to fold the call
// with the given arguments (TODO confirm against the .cpp).
static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink);
};
/// An int val identified by a subtype witness and a key decl; presumably the
/// value of requirement `key` looked up through `witness` (TODO confirm).
///
/// Operand layout: [0] = type, [1] = subtype witness, [2] = key decl.
class WitnessLookupIntVal : public IntVal
{
SLANG_AST_CLASS(WitnessLookupIntVal)
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
SubtypeWitness* getWitness() { return as<SubtypeWitness>(getOperand(1)); }
Decl* getKey() { return as<Decl>(getDeclOperand(2)); }
WitnessLookupIntVal(Type* inType, SubtypeWitness* witness, Decl* key)
{
setOperands(inType, witness, key);
}
// Static helpers defined out of line. Judging by the names, `tryFoldOrNull`
// may return null when no fold is possible, while `tryFold` takes an explicit
// result type — NOTE(review): confirm against the .cpp.
static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key);
static Val* tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type);
};
// polynomial expression "2*a*b^3 + 1" will be represented as:
// { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] }
/// One `param^power` factor of a polynomial term.
///
/// Operand layout: [0] = parameter IntVal, [1] = power. Unlike the IntVal
/// subclasses, this class stores no Type operand.
class PolynomialIntValFactor : public Val
{
SLANG_AST_CLASS(PolynomialIntValFactor)
public:
IntVal* getParam() const { return as<IntVal>(getOperand(0)); }
IntegerLiteralValue getPower() const { return getIntConstOperand(1); }
PolynomialIntValFactor(IntVal* inParam, IntegerLiteralValue inPower)
{
setOperands(inParam, inPower);
}
Val* _resolveImplOverride();
// for sorting only.
// Ordering: factors whose param is a GenericParamIntVal come before all other
// factors. Two generic-param factors compare by power if the params are
// `equals`, otherwise by the address of their param decls. All other factors
// compare by param pointer identity, then by power.
bool operator<(const PolynomialIntValFactor& other) const
{
if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
{
if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
{
if (thisGenParam->equals(thatGenParam))
return getPower() < other.getPower();
else
return thisGenParam->getDeclRef().getDecl() < thatGenParam->getDeclRef().getDecl();
}
else
{
// Generic-param factors sort first.
return true;
}
}
else
{
if (const auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
{
return false;
}
return getParam() == other.getParam() ? getPower() < other.getPower() : getParam() < other.getParam();
}
}
// for sorting only.
// Generic-param factors are equal when their params are `equals` and their
// powers match; any other pair compares param pointers and powers.
bool operator==(const PolynomialIntValFactor& other) const
{
if (auto thisGenParam = as<GenericParamIntVal>(getParam()))
{
if (auto thatGenParam = as<GenericParamIntVal>(other.getParam()))
{
if (thisGenParam->equals(thatGenParam) && getPower() == other.getPower())
return true;
}
return false;
}
return getPower() == other.getPower() && getParam() == other.getParam();
}
// Structural equality: same power and `equals`-equal params.
bool equals(const PolynomialIntValFactor& other) const
{
return getPower() == other.getPower() && getParam()->equals(other.getParam());
}
};
/// One `constFactor * factor0 * factor1 * ...` term of a polynomial.
///
/// Operand layout: [0] = constant factor, [1..] = PolynomialIntValFactor operands.
class PolynomialIntValTerm : public Val
{
SLANG_AST_CLASS(PolynomialIntValTerm)
public:
IntegerLiteralValue getConstFactor() const { return getIntConstOperand(0); }
OperandView<PolynomialIntValFactor> getParamFactors() const { return OperandView<PolynomialIntValFactor>(this, 1, getOperandCount() - 1); }
Val* _resolveImplOverride();
PolynomialIntValTerm(IntegerLiteralValue inConstFactor, ArrayView<PolynomialIntValFactor*> inParamFactors)
{
setOperands(inConstFactor);
addOperands(inParamFactors);
}
PolynomialIntValTerm(IntegerLiteralValue inConstFactor, OperandView<PolynomialIntValFactor> inParamFactors)
{
setOperands(inConstFactor);
addOperands(inParamFactors);
}
// True when both terms have the same number of factors and each factor pair
// is `equals` position by position. Constant factors are not compared, so
// e.g. `2*a` and `3*a` are combinable.
bool canCombineWith(const PolynomialIntValTerm& other) const
{
auto paramFactors = getParamFactors();
if (paramFactors.getCount() != other.getParamFactors().getCount())
return false;
for (Index i = 0; i < getParamFactors().getCount(); i++)
{
if (!paramFactors[i]->equals(*other.getParamFactors()[i]))
return false;
}
return true;
}
// Orders by constant factor, then compares factor lists element by element
// using PolynomialIntValFactor's `<` / `==`.
// NOTE(review): when one factor list is a proper prefix of the other, this
// returns false in both directions, so e.g. [a] ~ [a,b] and [a] ~ [a,c] while
// [a,b] < [a,c]. That is not a strict weak ordering; confirm that the callers
// that sort terms tolerate it.
bool operator<(const PolynomialIntValTerm& other) const
{
auto constFactor = getConstFactor();
auto paramFactors = getParamFactors();
if (constFactor < other.getConstFactor())
return true;
else if (constFactor == other.getConstFactor())
{
auto otherParamFactors = other.getParamFactors();
for (Index i = 0; i < paramFactors.getCount(); i++)
{
if (i >= otherParamFactors.getCount())
return false;
if (*(paramFactors[i]) < *(otherParamFactors[i]))
return true;
if (*(paramFactors[i]) == *(otherParamFactors[i]))
{
// Equal factors at this position: keep comparing.
}
else
{
return false;
}
}
}
return false;
}
};
/// An integer value in polynomial form: constantTerm + sum(terms).
///
/// Operand layout: [0] = type, [1] = constant term,
/// [2..] = PolynomialIntValTerm operands.
class PolynomialIntVal : public IntVal
{
SLANG_AST_CLASS(PolynomialIntVal)
public:
IntegerLiteralValue getConstantTerm() { return getIntConstOperand(1); }
OperandView<PolynomialIntValTerm> getTerms() { return OperandView<PolynomialIntValTerm>(this, 2, getOperandCount() - 2); }
// A polynomial is constant when it has no terms, i.e. it holds only the type
// and constant-term operands. The constructor always stores at least those
// two operands, and `getTerms()` starts at operand 2.
bool isConstant() { return getOperandCount() == 2; }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
// Arithmetic helpers over IntVal operands; defined out of line.
static IntVal* neg(ASTBuilder* astBuilder, IntVal* base);
static IntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
static IntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
static IntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
PolynomialIntVal(Type* inType, IntegerLiteralValue inConstantTerm, ArrayView<PolynomialIntValTerm*> inTerms)
{
setOperands(inType, inConstantTerm);
addOperands(inTerms);
}
};
/// An unknown integer value indicating an erroneous sub-expression
/// Operand layout: [0] = type.
class ErrorIntVal : public IntVal
{
SLANG_AST_CLASS(ErrorIntVal)
ErrorIntVal(Type* inType) { setOperands(inType); }
// TODO: We should probably eventually just have an `ErrorVal` here
// and have all `Val`s that represent ordinary values hold their
// `Type` so that we can have an `ErrorVal` of any type.
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// Resolves to itself.
Val* _resolveImplOverride() { return this; }
};
// A witness to the fact that some proposition is true, encoded
// at the level of the type system.
//
// Given a generic like:
//
// void example<L>(L light)
// where L : ILight
// { ... }
//
// a call to `example()` needs two things for us to be sure
// it is valid:
//
// 1. We need a type `X` to use as the argument for the
// parameter `L`. We might supply this explicitly, or
// via inference.
//
// 2. We need a *proof* that whatever `X` we chose conforms
// to the `ILight` interface.
//
// The easiest way to make such a proof is by construction,
// and a `Witness` represents such a constructive proof.
// Conceptually a proposition like `X : ILight` can be
// seen as a type, and a witness proving that proposition
// is a value of that type.
//
// We construct and store witnesses explicitly during
// semantic checking because they can help us with
// generating downstream code. By following the structure
// of a witness (the structure of a proof) we can, e.g.,
// navigate from the knowledge that `X : ILight` to
// the concrete declarations that provide the implementation
// of `ILight` for `X`.
//
// Abstract base for all witness values; it declares no members of its own.
class Witness : public Val
{
SLANG_ABSTRACT_AST_CLASS(Witness)
};
// A witness that one type is a subtype of another
// (where by "subtype" we include both inheritance
// relationships and type-conforms-to-interface relationships)
//
// TODO: we may need to tease those apart.
// Operand layout shared by all subclasses: [0] = sub type, [1] = super type.
class SubtypeWitness : public Witness
{
SLANG_ABSTRACT_AST_CLASS(SubtypeWitness)
Val* _resolveImplOverride();
Type* getSub() { return as<Type>(getOperand(0)); }
Type* getSup() { return as<Type>(getOperand(1)); }
// Override hook plus a public entry point; both are defined out of line.
ConversionCost _getOverloadResolutionCostOverride();
ConversionCost getOverloadResolutionCost();
};
// Presumably witnesses that `sub` and `sup` are the same type (name-based —
// TODO confirm). Operand layout: [0] = sub type, [1] = super type.
class TypeEqualityWitness : public SubtypeWitness
{
SLANG_AST_CLASS(TypeEqualityWitness)
TypeEqualityWitness(Type* subType, Type* supType)
{
setOperands(subType, supType);
}
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// A witness that one type is a subtype of another
// because some in-scope declaration says so
// Operand layout: [0] = sub type, [1] = super type, [2] = decl-ref to the
// declaration that establishes the relationship.
class DeclaredSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(DeclaredSubtypeWitness)
DeclRef<Decl> getDeclRef()
{
return as<DeclRefBase>(getOperand(2));
}
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
DeclaredSubtypeWitness(Type* inSub, Type* inSup, DeclRef<Decl> inDeclRef)
{
setOperands(inSub, inSup, inDeclRef);
}
ConversionCost _getOverloadResolutionCostOverride();
};
// A witness that `sub : sup` because `sub : mid` and `mid : sup`
// Operand layout: [0] = sub type, [1] = super type,
// [2] = `sub : mid` witness, [3] = `mid : sup` witness.
class TransitiveSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(TransitiveSubtypeWitness)
// Witness that `sub : mid`
SubtypeWitness* getSubToMid()
{
return as<SubtypeWitness>(getOperand(2));
}
// Witness that `mid : sup`
SubtypeWitness* getMidToSup()
{
return as<SubtypeWitness>(getOperand(3));
}
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
TransitiveSubtypeWitness(Type* subType, Type* supType, SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup)
{
setOperands(subType, supType, inSubToMid, inMidToSup);
}
ConversionCost _getOverloadResolutionCostOverride();
};
// A witness that `sub : sup` because `sub` was wrapped into
// an existential of type `sup`.
// Operand layout: [0] = sub type, [1] = super type, [2] = decl-ref to the
// opened existential value.
// NOTE(review): the constructor accepts a `DeclRef<Decl>` while the getter
// exposes `DeclRef<VarDeclBase>`; callers presumably pass a VarDeclBase decl —
// confirm.
class ExtractExistentialSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ExtractExistentialSubtypeWitness)
// The declaration of the existential value that has been opened
DeclRef<VarDeclBase> getDeclRef() { return as<DeclRefBase>(getOperand(2)); }
ExtractExistentialSubtypeWitness(Type* inSub, Type* inSup, DeclRef<Decl> inDeclRef)
{
setOperands(inSub, inSup, inDeclRef);
}
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// A witness of the fact that a user provided "__Dynamic" type argument is a
/// subtype to the existential type parameter.
// Operand layout: [0] = sub type, [1] = super type. The class stores nothing
// beyond the base SubtypeWitness operands.
class DynamicSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(DynamicSubtypeWitness)
DynamicSubtypeWitness(Type* inSub, Type* inSup)
{
setOperands(inSub, inSup);
}
};
/// A witness that `T : L & R` because `T : L` and `T : R`
// Operand layout: [0] = sub type, [1] = super type (`L & R`),
// [2] = `T : L` witness, [3] = `T : R` witness.
class ConjunctionSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ConjunctionSubtypeWitness)
// At the operational level, this class of witness is
// an operation that takes two witness tables `leftWitness`
// and `rightWitness`, and forms a pair/tuple of
// `(leftWitness, rightWitness)`.
static const int kComponentCount = 2;
ConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* left, SubtypeWitness* right)
{
setOperands(inSub, inSup, left, right);
}
SubtypeWitness* getLeftWitness() const { return as<SubtypeWitness>(getOperand(2)); }
SubtypeWitness* getRightWitness() const { return as<SubtypeWitness>(getOperand(3)); }
// Number of component witnesses. Uses `kComponentCount` rather than a
// repeated literal so it stays in sync with the bound asserted in
// `getComponentWitness`.
Count getComponentCount() const { return kComponentCount; }
// Component witness at `index` (0 = left, 1 = right); component witnesses
// start at operand 2.
SubtypeWitness* getComponentWitness(Index index) const
{
SLANG_ASSERT(index >= 0 && index < kComponentCount);
return as<SubtypeWitness>(getOperand(2 + index));
}
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
ConversionCost _getOverloadResolutionCostOverride();
};
/// A witness that `T <: L` or `T <: R` because `T <: L&R`
// Operand layout: [0] = sub type, [1] = super type, [2] = conjunction
// witness, [3] = index within the conjunction (stored as an int constant).
class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ExtractFromConjunctionSubtypeWitness)
// At the operational level, this class of witness is
// an operation that takes a pair/tuple of witness tables
// `(leftWitness, rightWitness)` and extracts one of the
// elements of it.
/// Witness that `T < L & R`
SubtypeWitness* getConjunctionWitness() { return as<SubtypeWitness>(getOperand(2)); };
ExtractFromConjunctionSubtypeWitness(Type* inSub, Type* inSup, SubtypeWitness* witness, int index)
{
setOperands(inSub, inSup, witness, index);
}
/// The zero-based index of the super-type we care about in the conjunction
///
/// If `conjunctionWitness` is `T < L & R` then this index should be zero if
/// we want to represent `T < L` and one if we want `T < R`.
///
int getIndexInConjunction() { return (int)getIntConstOperand(3); };
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
ConversionCost _getOverloadResolutionCostOverride();
};
/// A value that represents a modifier attached to some other value
// Base of the modifier-value hierarchy below. Resolves to itself.
class ModifierVal : public Val
{
SLANG_AST_CLASS(ModifierVal)
Val* _resolveImplOverride() { return this; }
};
// Modifier value intended for types (by name); it declares no members of its own.
class TypeModifierVal : public ModifierVal
{
SLANG_AST_CLASS(TypeModifierVal)
};
// Common base for resource-format modifiers (UNorm/SNorm below); it declares
// no members of its own.
class ResourceFormatModifierVal : public TypeModifierVal
{
SLANG_AST_CLASS(ResourceFormatModifierVal)
};
// Presumably the `unorm` resource-format modifier (name-based — confirm).
class UNormModifierVal : public ResourceFormatModifierVal
{
SLANG_AST_CLASS(UNormModifierVal)
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// Presumably the `snorm` resource-format modifier (name-based — confirm).
class SNormModifierVal : public ResourceFormatModifierVal
{
SLANG_AST_CLASS(SNormModifierVal)
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// Presumably the `no_diff` type modifier used by autodiff (name-based — confirm).
class NoDiffModifierVal : public TypeModifierVal
{
SLANG_AST_CLASS(NoDiffModifierVal)
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// Represents the result of differentiating a function.
// Operand layout: [0] = decl-ref to the function being differentiated.
// Every subclass below forwards to this constructor, so they share the layout.
class DifferentiateVal : public Val
{
SLANG_AST_CLASS(DifferentiateVal)
DifferentiateVal(DeclRef<Decl> inFunc)
{
setOperands(inFunc);
}
DeclRef<Decl> getFunc() { return as<DeclRefBase>(getOperand(0)); }
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
};
// Forward-mode variant; stores only the function operand via DifferentiateVal.
class ForwardDifferentiateVal : public DifferentiateVal
{
SLANG_AST_CLASS(ForwardDifferentiateVal)
ForwardDifferentiateVal(DeclRef<Decl> inFunc)
: DifferentiateVal(inFunc)
{}
};
// Backward-mode variant; stores only the function operand via DifferentiateVal.
class BackwardDifferentiateVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiateVal)
BackwardDifferentiateVal(DeclRef<Decl> inFunc)
: DifferentiateVal(inFunc)
{}
};
// Backward-mode "intermediate type" variant (by name); stores only the
// function operand via DifferentiateVal.
class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal)
BackwardDifferentiateIntermediateTypeVal(DeclRef<Decl> inFunc)
: DifferentiateVal(inFunc)
{}
};
// Backward-mode "primal" variant (by name); stores only the function operand
// via DifferentiateVal.
class BackwardDifferentiatePrimalVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiatePrimalVal)
BackwardDifferentiatePrimalVal(DeclRef<Decl> inFunc)
: DifferentiateVal(inFunc)
{}
};
// Backward-mode "propagate" variant (by name); stores only the function operand
// via DifferentiateVal.
class BackwardDifferentiatePropagateVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiatePropagateVal)
BackwardDifferentiatePropagateVal(DeclRef<Decl> inFunc)
: DifferentiateVal(inFunc)
{}
};
template<typename F>
void SubstitutionSet::forEachGenericSubstitution(F func) const
{
// An empty substitution set has no decl-ref chain to walk.
if (!declRef)
return;
// Follow the chain from `declRef` through successive `getBase()` links and
// report every generic application as (generic decl, argument list).
auto cursor = declRef;
while (cursor)
{
auto genericApp = as<GenericAppDeclRef>(cursor);
if (genericApp)
{
func(genericApp->getGenericDecl(), genericApp->getArgs());
}
cursor = cursor->getBase();
}
}
template<typename F>
void SubstitutionSet::forEachSubstitutionArg(F func) const
{
// An empty substitution set has no decl-ref chain to walk.
if (!declRef)
return;
// Follow the chain from `declRef` through successive `getBase()` links.
// Generic applications contribute each of their arguments; lookup decl-refs
// contribute the sub type of their witness.
auto cursor = declRef;
while (cursor)
{
if (auto genericApp = as<GenericAppDeclRef>(cursor))
{
for (auto genericArg : genericApp->getArgs())
func(genericArg);
}
else if (auto lookup = as<LookupDeclRef>(cursor))
{
func(lookup->getWitness()->getSub());
}
cursor = cursor->getBase();
}
}
} // namespace Slang