https://github.com/JuliaLang/julia
Tip revision: 6ba0aaf0d9a592682155cc5d23f308b427e48c65 authored by Shuhei Kadowaki on 11 April 2024, 15:28:24 UTC
more test update
more test update
Tip revision: 6ba0aaf
inferenceresult.jl
# This file is a part of Julia. License is MIT: https://julialang.org/license
# Most general cache argtypes for `mi`: derived purely from its specialization
# signature, with the owning `Method` (or `nothing` when `mi.def` is not a `Method`).
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance)
    owner = mi.def
    method = owner isa Method ? owner : nothing
    return most_general_argtypes(method, mi.specTypes)
end
# Plain wrapper around the argument types of a call. Accepted by the 4-argument
# `matching_cache_argtypes`, which passes each entry through `widenslotwrapper`
# before matching it against the cache argtypes.
struct SimpleArgtypes
    argtypes::Vector{Any} # the call's argument types
end
# Like `SimpleArgtypes`, but allows the argtypes to be wider than the current call.
# As a result, it is not legal to refine the cache result with information more
# precise than what was deducible from the `WidenedArgtypes`.
struct WidenedArgtypes
    # argument types; may be wider than those of the call being inferred
    argtypes::Vector{Any}
end
# Build the argtypes used to look up / populate the inference cache for `mi` from
# explicitly supplied argtypes:
#   1. every supplied argtype is passed through `widenslotwrapper`,
#   2. the result is reshaped to `mi`'s slot layout by `va_process_argtypes`,
#   3. `pick_const_args!` then merges it with the precomputed `cache_argtypes`.
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
                                 simple_argtypes::Union{SimpleArgtypes, WidenedArgtypes},
                                 cache_argtypes::Vector{Any})
    unwrapped = Any[widenslotwrapper(argtype) for argtype in simple_argtypes.argtypes]
    per_slot = va_process_argtypes(𝕃, unwrapped, mi)
    return pick_const_args!(𝕃, per_slot, cache_argtypes)
end
# Merge `given_argtypes` with `cache_argtypes` in place, slot by slot, and return
# `given_argtypes`. Where `is_argtype_match` (with no const override) reports a match,
# the cached argtype is used. Otherwise the given argtype is kept, except that a
# `PartialStruct` that is not `⊏` its declared `Type` is narrowed with `tmeet`.
# Both vectors must have the same length (asserted).
function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
    n = length(given_argtypes)
    @assert n == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
    for (slot, cached) in enumerate(cache_argtypes)
        given = given_argtypes[slot]
        if is_argtype_match(𝕃, given, cached, false)
            # nothing better was supplied for this slot: reuse the cached argtype
            given_argtypes[slot] = cached
            continue
        end
        # the supplied argtype is preferred; only a `PartialStruct` that is not already
        # `⊏` the declared signature type gets narrowed down to it
        if given isa PartialStruct && cached isa Type && !⊏(𝕃, given, cached)
            given_argtypes[slot] = tmeet(𝕃, given, cached)
        end
    end
    return given_argtypes
end
# Decide whether `given_argtype` is interchangeable with `cache_argtype`.
# A forwardable given argtype matches only if it is lattice-equal to the cached one;
# a non-forwardable one matches unless the cached slot was overridden by a constant.
function is_argtype_match(𝕃::AbstractLattice,
                          @nospecialize(given_argtype),
                          @nospecialize(cache_argtype),
                          overridden_by_const::Bool)
    is_forwardable_argtype(𝕃, given_argtype) || return !overridden_by_const
    return is_lattice_equal(𝕃, given_argtype, cache_argtype)
end
# Convenience method: same as the handler-taking method below, but with a handler
# that ignores its arguments and does nothing.
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
    ignore_va = Returns(nothing)
    return va_process_argtypes(ignore_va, 𝕃, given_argtypes, mi)
end
# Adapt `given_argtypes` to the slot layout of `mi`'s method, which has `nargs` slots.
# - If the method is varargs, or the last given argtype is a `Vararg`, a fresh
#   `nargs`-long vector is built. The leading slots are filled via `argtype_by_index`.
#   For a varargs method the final slot holds `tuple_tfunc` applied to the trailing given
#   argtypes, and `va_handler!(new_argtypes, first_folded_index)` is then invoked.
# - Otherwise `given_argtypes` must already have exactly `nargs` entries (asserted) and
#   is returned as-is (the same vector, not a copy).
# `mi.def` must be a `Method`; the typeassert throws a `TypeError` otherwise.
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
    def = mi.def::Method
    isva = def.isva
    nargs = Int(def.nargs)
    # NOTE(review): when `isva` is false and `given_argtypes` is empty,
    # `given_argtypes[end]` throws a `BoundsError` rather than reaching the assert below
    if isva || isvarargtype(given_argtypes[end])
        isva_given_argtypes = Vector{Any}(undef, nargs)
        # fill the non-vararg slots (all `nargs` slots when `isva` is false; the Bool
        # `isva` is used as 0/1 here). Presumably `argtype_by_index` expands a trailing
        # `Vararg` for indices past the end of `given_argtypes` — TODO confirm.
        for i = 1:(nargs-isva)
            isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
        end
        if isva
            # `last` is the index of the first given argtype folded into the vararg tuple:
            # the trailing `Vararg` itself when fewer than `nargs` argtypes were given,
            # otherwise `nargs` (which yields an empty tuple if fewer were given).
            # Note: this local shadows `Base.last`.
            if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
                last = length(given_argtypes)
            else
                last = nargs
            end
            isva_given_argtypes[nargs] = tuple_tfunc(𝕃, given_argtypes[last:end])
            va_handler!(isva_given_argtypes, last)
        end
        return isva_given_argtypes
    end
    @assert length(given_argtypes) == nargs "invalid `given_argtypes` for `mi`"
    return given_argtypes
end
# Compute the most general ("cache") argtypes for a method specialized on `specTypes`.
# Returns a `Vector{Any}` with one entry per method slot: `method.nargs` entries, or
# zero when `method === nothing` (toplevel). Singleton types and `Type{T}` constants
# become `Const` objects, and free type variables are removed with `elim_free_typevars`.
# For a varargs method the last slot describes the trailing arguments as a tuple type
# (or the `tuple_tfunc` result of the individual trailing argtypes).
function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
    toplevel = method === nothing
    isva = !toplevel && method.isva
    mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
    nargs::Int = toplevel ? 0 : method.nargs
    cache_argtypes = Vector{Any}(undef, nargs)
    # First, if we're dealing with a varargs method, then we set the last element of
    # `cache_argtypes` to the appropriate `Tuple` type or `PartialStruct` instance.
    mi_argtypes_length = length(mi_argtypes)
    if !toplevel && isva
        if specTypes::Type == Tuple
            # fully unspecialized signature: every slot is `Any`, the vararg slot `Tuple`.
            # NOTE(review): `mi_argtypes_length` still holds the length computed before this
            # reassignment, so only the first entry is read below; the result is all `Any`
            # either way, which makes `mi_argtypes[end] = Tuple` effectively unused — confirm.
            mi_argtypes = Any[Any for i = 1:nargs]
            if nargs > 1
                mi_argtypes[end] = Tuple
            end
            vargtype = Tuple
        else
            if nargs > mi_argtypes_length
                # fewer signature parameters than slots: the vararg slot gets the trailing
                # `Vararg` with its length unconstrained, or `Tuple{}` if there is none
                va = mi_argtypes[mi_argtypes_length]
                if isvarargtype(va)
                    new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
                    vargtype = Tuple{new_va}
                else
                    vargtype = Tuple{}
                end
            else
                # collect parameters `nargs:end` as the elements of the vararg tuple
                vargtype_elements = Any[]
                for i in nargs:mi_argtypes_length
                    p = mi_argtypes[i]
                    p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
                    push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
                end
                for i in 1:length(vargtype_elements)
                    atyp = vargtype_elements[i]
                    if issingletontype(atyp)
                        # replace singleton types with their equivalent Const object
                        vargtype_elements[i] = Const(atyp.instance)
                    elseif isconstType(atyp)
                        vargtype_elements[i] = Const(atyp.parameters[1])
                    end
                end
                vargtype = tuple_tfunc(fallback_lattice, vargtype_elements)
            end
        end
        cache_argtypes[nargs] = vargtype
        nargs -= 1
    end
    # Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
    # type info as we go (where possible). Note that if we're dealing with a varargs method,
    # we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
    # we don't overwrite the result of that work here).
    if mi_argtypes_length > 0
        tail_index = nargtypes = min(mi_argtypes_length, nargs)
        local lastatype
        for i = 1:nargtypes
            atyp = mi_argtypes[i]
            if i == nargtypes && isvarargtype(atyp)
                # a trailing `Vararg{T}` stands for all remaining slots: use `T` here and
                # let the loop below copy it into slots `tail_index+1:nargs`
                atyp = unwrapva(atyp)
                tail_index -= 1
            end
            atyp = unwraptv(atyp)
            if issingletontype(atyp)
                # replace singleton types with their equivalent Const object
                atyp = Const(atyp.instance)
            elseif isconstType(atyp)
                atyp = Const(atyp.parameters[1])
            else
                atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
            end
            i == nargtypes && (lastatype = atyp)
            cache_argtypes[i] = atyp
        end
        # fill any remaining (non-vararg) slots with the last processed argtype;
        # this range is empty when `nargtypes == 0`, so `lastatype` is never read unset
        for i = (tail_index+1):nargs
            cache_argtypes[i] = lastatype
        end
    else
        @assert nargs == 0 "invalid specialization of method" # wrong number of arguments
    end
    return cache_argtypes
end
# Remove free `TypeVar`s so that later stages never have to cope with them.
# At runtime only a `Type{...}::DataType` can carry invalid type parameters; any other
# malformed type seen here must be a user-constructed type argument supplied at an
# inference entry point. Hence a malformed `Type{...}` is widened to `Type`, and every
# other type with free type variables is widened to `Any`. Other types pass through.
function elim_free_typevars(@nospecialize t)
    has_free_typevars(t) || return t
    return isType(t) ? Type : Any
end
# Find an `InferenceResult` in `cache` for `mi` whose argtypes match `given_argtypes`
# slot by slot (per `is_argtype_match`, honoring each result's `overridden_by_const`
# flags). Returns the first such result, or `nothing` when none matches.
# `given_argtypes` must have exactly `nargs` entries for `mi`'s method (asserted).
function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any},
                      cache::Vector{InferenceResult})
    nslots = Int((mi.def::Method).nargs)
    n = length(given_argtypes)
    @assert n == nslots "invalid `given_argtypes` for `mi`"
    for candidate in cache
        candidate.linfo === mi || continue
        candidate_argtypes = candidate.argtypes
        @assert length(candidate_argtypes) == n "invalid `cache_argtypes` for `mi`"
        overridden = candidate.overridden_by_const::BitVector
        all_match = true
        for slot in 1:n
            if !is_argtype_match(𝕃, given_argtypes[slot], candidate_argtypes[slot], overridden[slot])
                all_match = false
                break
            end
        end
        all_match && return candidate
    end
    return nothing
end