https://github.com/shader-slang/slang
Tip revision: 259a015feb9d4ab65e8fbba32f6c777e92780cc7 authored by Yong He on 23 March 2023, 04:16:35 UTC
Type legalization and autodiff bug fixes. (#2722)
Type legalization and autodiff bug fixes. (#2722)
Tip revision: 259a015
slang-ast-val.h
// slang-ast-val.h
#pragma once
#include "slang-ast-base.h"
namespace Slang {
// Syntax class definitions for compile-time values.
// A compile-time integer (may not have a specific concrete value).
// Declared abstract via SLANG_ABSTRACT_AST_CLASS; it is the base class of the
// integer value classes in this header (ConstantIntVal, GenericParamIntVal,
// FuncCallIntVal, WitnessLookupIntVal, PolynomialIntVal, ErrorIntVal).
class IntVal : public Val
{
SLANG_ABSTRACT_AST_CLASS(IntVal)
// The `Type` supplied at construction.
// NOTE(review): presumably the type of the integer value itself -- confirm.
Type* type;
// Stores `inType` in `type`.
IntVal(Type* inType)
: type(inType)
{}
};
// Trivial case of a value that is just a constant integer.
class ConstantIntVal : public IntVal
{
SLANG_AST_CLASS(ConstantIntVal)
// The constant integer this value holds.
IntegerLiteralValue value;
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
protected:
// Initializes the `IntVal` base with `inType` and `value` with `inValue`.
ConstantIntVal(Type* inType, IntegerLiteralValue inValue)
: IntVal(inType), value(inValue)
{}
};
// The logical "value" of a reference to a generic value parameter.
class GenericParamIntVal : public IntVal
{
SLANG_AST_CLASS(GenericParamIntVal)
// Reference to the generic value parameter declaration this value stands for.
DeclRef<VarDeclBase> declRef;
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// Builds `declRef` from a declaration plus the substitutions to apply to it.
GenericParamIntVal(Type* inType, VarDeclBase* inDecl, Substitutions* inSubst)
: IntVal(inType), declRef(inDecl, inSubst)
{}
protected:
// Variant that copies an already-formed declaration reference.
GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef)
: IntVal(inType), declRef(inDeclRef)
{}
};
// A compile-time integer value that is the result of some general computation.
class FuncCallIntVal : public IntVal
{
SLANG_AST_CLASS(FuncCallIntVal)
// Overrides declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// NOTE(review): presumably the function being called, the type of that
// function, and the integer arguments passed to it -- confirm against the
// implementation file.
DeclRef<Decl> funcDeclRef;
Type* funcType;
List<IntVal*> args;
// Only initializes the `IntVal` base; the fields above are left for the caller to fill in.
FuncCallIntVal(Type* inType) : IntVal(inType) {}
// NOTE(review): judging by the name, attempts to fold a call to `newFuncDecl`
// with `newArgs` into a simpler value; the exact contract (including when
// `sink` is used) is defined in the implementation file -- confirm there.
static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink);
};
// An integer value identified by a subtype witness and a key declaration.
// NOTE(review): presumably the value obtained by looking up `key` through
// `witness` -- confirm against the implementation file.
class WitnessLookupIntVal : public IntVal
{
SLANG_AST_CLASS(WitnessLookupIntVal)
// Overrides declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// The witness to look through.
SubtypeWitness* witness;
// The declaration used as the lookup key.
Decl* key;
// Only initializes the `IntVal` base; `witness` and `key` are left for the caller to fill in.
WitnessLookupIntVal(Type* inType) : IntVal(inType) {}
// NOTE(review): by their names these try to resolve the lookup to a simpler
// value; `tryFoldOrNull` presumably returns null when that is not possible.
// The exact behavior is defined in the implementation file -- confirm there.
static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key);
static Val* tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type);
};
// polynomial expression "2*a*b^3 + 1" will be represented as:
// { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] }
// One `param ^ power` factor inside a polynomial term, e.g. `b^3`.
class PolynomialIntValFactor : public NodeBase
{
    SLANG_AST_CLASS(PolynomialIntValFactor)
public:
    // The value being raised to a power.
    IntVal* param;
    // The exponent applied to `param`.
    IntegerLiteralValue power;

    // Ordering used for sorting only.
    //
    // Factors whose `param` is a `GenericParamIntVal` order before all other
    // factors. Two generic-parameter factors that are equal as values are
    // ordered by `power`; otherwise by the address of the referenced
    // declaration. Two non-generic factors are ordered by `param` address,
    // then by `power`.
    bool operator<(const PolynomialIntValFactor& other) const
    {
        auto lhsGeneric = as<GenericParamIntVal>(param);
        auto rhsGeneric = as<GenericParamIntVal>(other.param);

        if (lhsGeneric && rhsGeneric)
        {
            if (lhsGeneric->equalsVal(rhsGeneric))
                return power < other.power;
            return lhsGeneric->declRef.decl < rhsGeneric->declRef.decl;
        }

        // Exactly one side is a generic parameter: that side sorts first.
        if (lhsGeneric)
            return true;
        if (rhsGeneric)
            return false;

        // Neither side is a generic parameter: fall back to pointer identity.
        if (param == other.param)
            return power < other.power;
        return param < other.param;
    }

    // Equality used for sorting only.
    //
    // When this factor's `param` is a generic parameter, the other must be one
    // too and is compared with `equalsVal`; otherwise `param` is compared by
    // pointer identity. `power` must match in both cases.
    bool operator==(const PolynomialIntValFactor& other) const
    {
        if (!as<GenericParamIntVal>(param))
            return power == other.power && param == other.param;

        if (!as<GenericParamIntVal>(other.param))
            return false;

        return param->equalsVal(other.param) && power == other.power;
    }

    // Value equality: same `power` and `param` values that compare equal
    // through `equalsVal`, whatever kind of `IntVal` they are.
    bool equals(const PolynomialIntValFactor& other) const
    {
        if (power != other.power)
            return false;
        return param->equalsVal(other.param);
    }
};
class PolynomialIntValTerm : public NodeBase
{
SLANG_AST_CLASS(PolynomialIntValTerm)
public:
IntegerLiteralValue constFactor;
List<PolynomialIntValFactor*> paramFactors;
bool canCombineWith(const PolynomialIntValTerm& other) const
{
if (paramFactors.getCount() != other.paramFactors.getCount())
return false;
for (Index i = 0; i < paramFactors.getCount(); i++)
{
if (!paramFactors[i]->equals(*other.paramFactors[i]))
return false;
}
return true;
}
bool operator<(const PolynomialIntValTerm& other) const
{
if (constFactor < other.constFactor)
return true;
else if (constFactor == other.constFactor)
{
for (Index i = 0; i < paramFactors.getCount(); i++)
{
if (i >= other.paramFactors.getCount())
return false;
if (*(paramFactors[i]) < *(other.paramFactors[i]))
return true;
if (*(paramFactors[i]) == *(other.paramFactors[i]))
{
}
else
{
return false;
}
}
}
return false;
}
};
// An integer value expressed as a polynomial: `constantTerm` plus the sum of
// every entry in `terms`.
class PolynomialIntVal : public IntVal
{
SLANG_AST_CLASS(PolynomialIntVal)
public:
// The non-constant terms of the polynomial.
List<PolynomialIntValTerm*> terms;
// The constant term of the polynomial; defaults to 0.
IntegerLiteralValue constantTerm = 0;
// True when there are no entries in `terms`, i.e. only `constantTerm` remains.
bool isConstant() { return terms.getCount() == 0; }
// Canonicalize the polynomial. If the polynomial can be simplified to a constant or a generic
// parameter, the method returns the value it was simplified to.
// Otherwise, in-place modifications are performed and `this` is returned.
IntVal* canonicalize(ASTBuilder* builder);
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// NOTE(review): by their names these build the polynomial for the negation,
// sum, difference and product of their operands; whether the result is
// already canonicalized is decided in the implementation file -- confirm there.
static PolynomialIntVal* neg(ASTBuilder* astBuilder, IntVal* base);
static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1);
// Only initializes the `IntVal` base; `terms` starts empty and `constantTerm` at 0.
PolynomialIntVal(Type* inType) : IntVal(inType) {}
};
/// An unknown integer value indicating an erroneous sub-expression.
class ErrorIntVal : public IntVal
{
SLANG_AST_CLASS(ErrorIntVal)
// Only initializes the `IntVal` base; this class adds no fields of its own.
ErrorIntVal(Type* inType) : IntVal(inType) {}
// TODO: We should probably eventually just have an `ErrorVal` here
// and have all `Val`s that represent ordinary values hold their
// `Type` so that we can have an `ErrorVal` of any type.
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// A witness to the fact that some proposition is true, encoded
// at the level of the type system.
//
// Given a generic like:
//
// void example<L>(L light)
// where L : ILight
// { ... }
//
// a call to `example()` needs two things for us to be sure
// it is valid:
//
// 1. We need a type `X` to use as the argument for the
// parameter `L`. We might supply this explicitly, or
// via inference.
//
// 2. We need a *proof* that whatever `X` we chose conforms
// to the `ILight` interface.
//
// The easiest way to make such a proof is by construction,
// and a `Witness` represents such a constructive proof.
// Conceptually a proposition like `X : ILight` can be
// seen as a type, and a witness proving that proposition
// is a value of that type.
//
// We construct and store witnesses explicitly during
// semantic checking because they can help us with
// generating downstream code. By following the structure
// of a witness (the structure of a proof) we can, e.g.,
// navigate from the knowledge that `X : ILight` to
// the concrete declarations that provide the implementation
// of `ILight` for `X`.
//
// Abstract base class (declared via SLANG_ABSTRACT_AST_CLASS) for witness
// values; it declares no members of its own here.
class Witness : public Val
{
SLANG_ABSTRACT_AST_CLASS(Witness)
};
// A witness that one type is a subtype of another
// (where by "subtype" we include both inheritance
// relationships and type-conforms-to-interface relationships).
//
// TODO: we may need to tease those apart.
class SubtypeWitness : public Witness
{
SLANG_ABSTRACT_AST_CLASS(SubtypeWitness)
// The subtype side of the relationship; null until assigned.
Type* sub = nullptr;
// The supertype side of the relationship; null until assigned.
Type* sup = nullptr;
};
// NOTE(review): by its name, a witness used when `sub` and `sup` are the
// same type -- confirm against the implementation file. Adds no fields
// beyond those inherited from `SubtypeWitness`.
class TypeEqualityWitness : public SubtypeWitness
{
SLANG_AST_CLASS(TypeEqualityWitness)
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// A witness that one type is a subtype of another
// because some in-scope declaration says so.
class DeclaredSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(DeclaredSubtypeWitness)
// Reference to the declaration that establishes the subtype relationship.
DeclRef<Decl> declRef;
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// Builds `declRef` from `decl` and `subst`, and records the two types in the
// inherited `sub` / `sup` fields.
DeclaredSubtypeWitness(Type* inSub, Type* inSup, Decl* decl, Substitutions* subst)
: declRef(decl, subst)
{
sub = inSub;
sup = inSup;
}
};
// A witness that `sub : sup` because `sub : mid` and `mid : sup`.
class TransitiveSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(TransitiveSubtypeWitness)
// Witness that `sub : mid`
SubtypeWitness* subToMid = nullptr;
// Witness that `mid : sup`
SubtypeWitness* midToSup = nullptr;
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
// Stores the two component witnesses. The inherited `sub` / `sup` fields are
// not assigned by this constructor.
TransitiveSubtypeWitness(SubtypeWitness* inSubToMid, SubtypeWitness* inMidToSup)
: subToMid(inSubToMid), midToSup(inMidToSup)
{}
};
// A witness that `sub : sup` because `sub` was wrapped into
// an existential of type `sup`.
class ExtractExistentialSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ExtractExistentialSubtypeWitness)
// The declaration of the existential value that has been opened
DeclRef<VarDeclBase> declRef;
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// A witness that `sub : sup`, because `sub` is a tagged union
// of the form `A | B | C | ...` and each of `A : sup`,
// `B : sup`, `C : sup`, etc.
//
class TaggedUnionSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(TaggedUnionSubtypeWitness)
// Witnesses that each of the "case" types in the union
// is a subtype of `sup`.
//
// NOTE(review): presumably stored in the same order as the union's case
// types -- confirm against the code that fills this list.
List<Val*> caseWitnesses;
// Overrides should be public so base classes can access.
// They are only declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// A witness of the fact that `ThisType(someInterface) : someInterface`.
/// Declares no members beyond those inherited from `SubtypeWitness`.
class ThisTypeSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ThisTypeSubtypeWitness)
};
/// A witness of the fact that a user-provided "__Dynamic" type argument is a
/// subtype of the existential type parameter.
/// Declares no members beyond those inherited from `SubtypeWitness`.
class DynamicSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(DynamicSubtypeWitness)
};
/// A witness that `T : L & R` because `T : L` and `T : R`
class ConjunctionSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ConjunctionSubtypeWitness)
/// Witness that `sub : sup->left`
Val* leftWitness;
/// Witness that `sub : sup->right`
Val* rightWitness;
// Overrides declared here; their definitions are not part of this header.
// Note that neither witness pointer has a default initializer in this class.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// A witness that `T : X` because `T : X & Y` or `T : Y & X`
class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness
{
SLANG_AST_CLASS(ExtractFromConjunctionSubtypeWitness)
/// Witness that `T : L & R` for some `L` and `R`
Val* conjunctionWitness;
/// The zero-based index of the super-type we care about in the conjunction
///
/// If `conjunctionWitness` is `T : X & Y` then this index should be zero if
/// we want to represent `T : X` and one if we want `T : Y`.
///
int indexInConjunction;
// Overrides declared here; their definitions are not part of this header.
// Note that neither field above has a default initializer in this class.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// A value that represents a modifier attached to some other value.
/// Declares no fields; it is the base of the `...ModifierVal` classes below.
class ModifierVal : public Val
{
SLANG_AST_CLASS(ModifierVal)
// Equality and hashing overrides; declared here, defined outside this header.
bool _equalsValOverride(Val* val);
HashCode _getHashCodeOverride();
};
// Base class for modifier values that, by name, apply to types
// (NOTE(review): confirm intended scope). Declares no members of its own.
class TypeModifierVal : public ModifierVal
{
SLANG_AST_CLASS(TypeModifierVal)
};
// Base class of the resource-format modifier values in this header
// (`UNormModifierVal`, `SNormModifierVal`). Declares no members of its own.
class ResourceFormatModifierVal : public TypeModifierVal
{
SLANG_AST_CLASS(ResourceFormatModifierVal)
};
// Resource-format modifier value for "unorm" (per the class name;
// NOTE(review): confirm the exact source-level modifier this represents).
class UNormModifierVal : public ResourceFormatModifierVal
{
SLANG_AST_CLASS(UNormModifierVal)
// Text and substitution overrides; declared here, defined outside this header.
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// Resource-format modifier value for "snorm" (per the class name;
// NOTE(review): confirm the exact source-level modifier this represents).
class SNormModifierVal : public ResourceFormatModifierVal
{
SLANG_AST_CLASS(SNormModifierVal)
// Text and substitution overrides; declared here, defined outside this header.
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// Type modifier value for "no diff" (per the class name).
// NOTE(review): presumably marks a type as excluded from differentiation --
// confirm against the autodiff implementation.
class NoDiffModifierVal : public TypeModifierVal
{
SLANG_AST_CLASS(NoDiffModifierVal)
// Text and substitution overrides; declared here, defined outside this header.
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
/// Represents the result of differentiating a function.
/// Base class of the `...Differentiate...Val` classes that follow.
class DifferentiateVal : public Val
{
SLANG_AST_CLASS(DifferentiateVal)
// Reference to the declaration of the function being differentiated.
DeclRef<Decl> func;
// Overrides declared here; their definitions are not part of this header.
bool _equalsValOverride(Val* val);
void _toTextOverride(StringBuilder& out);
HashCode _getHashCodeOverride();
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
// `DifferentiateVal` variant that adds no members of its own.
// NOTE(review): by its name, denotes the forward-mode derivative of `func` -- confirm.
class ForwardDifferentiateVal : public DifferentiateVal
{
SLANG_AST_CLASS(ForwardDifferentiateVal)
};
// `DifferentiateVal` variant that adds no members of its own.
// NOTE(review): by its name, denotes the backward-mode derivative of `func` -- confirm.
class BackwardDifferentiateVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiateVal)
};
// `DifferentiateVal` variant that adds no members of its own.
// NOTE(review): by its name, refers to the intermediate type used by
// backward differentiation of `func` -- confirm against the autodiff code.
class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal)
};
// `DifferentiateVal` variant that adds no members of its own.
// NOTE(review): by its name, refers to the primal part of backward
// differentiation of `func` -- confirm against the autodiff code.
class BackwardDifferentiatePrimalVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiatePrimalVal)
};
// `DifferentiateVal` variant that adds no members of its own.
// NOTE(review): by its name, refers to the propagate part of backward
// differentiation of `func` -- confirm against the autodiff code.
class BackwardDifferentiatePropagateVal : public DifferentiateVal
{
SLANG_AST_CLASS(BackwardDifferentiatePropagateVal)
};
} // namespace Slang