https://github.com/JuliaLang/julia
Tip revision: 494420c6b65ae70b0ea7c5a7b95108c58f6005f1 authored by Keno Fischer on 09 March 2024, 00:21:34 UTC
Don't strip hygienic scope in doc macro
inferencestate.jl
# This file is a part of Julia. License is MIT: https://julialang.org/license
# data structures
# ===============
mutable struct BitSetBoundedMinPrioritySet <: AbstractSet{Int}
elems::BitSet
min::Int
# Stores whether min is exact or a lower bound
# If exact, it is not set in elems
min_exact::Bool
max::Int
end
function BitSetBoundedMinPrioritySet(max::Int)
bs = BitSet()
bs.offset = 0
BitSetBoundedMinPrioritySet(bs, max+1, true, max)
end
@noinline function _advance_bsbmp!(bsbmp::BitSetBoundedMinPrioritySet)
@assert !bsbmp.min_exact
bsbmp.min = _bits_findnext(bsbmp.elems.bits, bsbmp.min)::Int
bsbmp.min < 0 && (bsbmp.min = bsbmp.max + 1)
bsbmp.min_exact = true
delete!(bsbmp.elems, bsbmp.min)
return nothing
end
function isempty(bsbmp::BitSetBoundedMinPrioritySet)
if bsbmp.min > bsbmp.max
return true
end
bsbmp.min_exact && return false
_advance_bsbmp!(bsbmp)
return bsbmp.min > bsbmp.max
end
function popfirst!(bsbmp::BitSetBoundedMinPrioritySet)
bsbmp.min_exact || _advance_bsbmp!(bsbmp)
m = bsbmp.min
m > bsbmp.max && throw(ArgumentError("BitSetBoundedMinPrioritySet must be non-empty"))
bsbmp.min = m+1
bsbmp.min_exact = false
return m
end
function push!(bsbmp::BitSetBoundedMinPrioritySet, idx::Int)
if idx <= bsbmp.min
if bsbmp.min_exact && bsbmp.min < bsbmp.max && idx != bsbmp.min
push!(bsbmp.elems, bsbmp.min)
end
bsbmp.min = idx
bsbmp.min_exact = true
return nothing
end
push!(bsbmp.elems, idx)
return nothing
end
function in(idx::Int, bsbmp::BitSetBoundedMinPrioritySet)
if bsbmp.min_exact && idx == bsbmp.min
return true
end
return idx in bsbmp.elems
end
iterate(bsbmp::BitSetBoundedMinPrioritySet, s...) = iterate(bsbmp.elems, s...)
function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
for val in itr
push!(bsbmp, val)
end
end
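# A minimal usage sketch (illustrative, not part of the compiler): the set behaves as a
# bounded worklist of statement indices, popped in ascending order with duplicates folded:
#     W = BitSetBoundedMinPrioritySet(10)
#     push!(W, 7); push!(W, 3); push!(W, 7)
#     popfirst!(W)    # == 3
#     popfirst!(W)    # == 7
#     isempty(W)      # == true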
mutable struct TwoPhaseVectorView <: AbstractVector{Int}
const data::Vector{Int}
count::Int
const range::UnitRange{Int}
end
size(tpvv::TwoPhaseVectorView) = (tpvv.count,)
function getindex(tpvv::TwoPhaseVectorView, i::Int)
checkbounds(tpvv, i)
@inbounds tpvv.data[first(tpvv.range) + i - 1]
end
function push!(tpvv::TwoPhaseVectorView, v::Int)
tpvv.count += 1
tpvv.data[first(tpvv.range) + tpvv.count - 1] = v
return nothing
end
"""
mutable struct TwoPhaseDefUseMap
This struct is intended as a memory- and GC-pressure-efficient mechanism
for incrementally computing def-use maps. The idea is that the def-use map
is constructed in two passes over the IR. In the first, we simply count the
number of uses, computing the number of uses for each def as well as the
total number of uses. In the second pass, we actually fill in the def-use
information.
The idea is that either of these two phases can be combined with other useful
work that needs to scan the instruction stream anyway, while avoiding the
significant allocation pressure of e.g. allocating an array for every SSA value
or attempting to dynamically move things around as new uses are discovered.
The def-use map is presented as a vector of vectors. For every def, indexing
into the map will return a vector of uses.
"""
mutable struct TwoPhaseDefUseMap <: AbstractVector{TwoPhaseVectorView}
ssa_uses::Vector{Int}
data::Vector{Int}
complete::Bool
end
function complete!(tpdum::TwoPhaseDefUseMap)
cumsum = 0
for i = 1:length(tpdum.ssa_uses)
this_val = cumsum + 1
cumsum += tpdum.ssa_uses[i]
tpdum.ssa_uses[i] = this_val
end
resize!(tpdum.data, cumsum)
fill!(tpdum.data, 0)
tpdum.complete = true
end
function TwoPhaseDefUseMap(nssas::Int)
ssa_uses = zeros(Int, nssas)
data = Int[]
complete = false
return TwoPhaseDefUseMap(ssa_uses, data, complete)
end
function count!(tpdum::TwoPhaseDefUseMap, arg::SSAValue)
@assert !tpdum.complete
tpdum.ssa_uses[arg.id] += 1
end
function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int)
if !tpdum.complete
tpdum.ssa_uses[def] -= 1
else
range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1))
# TODO: Sorted
useidx = findfirst(idx->tpdum.data[idx] == use, range)
@assert useidx !== nothing
idx = range[useidx]
while idx < lastindex(range)
ndata = tpdum.data[idx+1]
ndata == 0 && break
tpdum.data[idx] = ndata
idx += 1
end
tpdum.data[idx] = 0
end
end
kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) =
kill_def_use!(tpdum, def.id, use)
function getindex(tpdum::TwoPhaseDefUseMap, idx::Int)
@assert tpdum.complete
range = tpdum.ssa_uses[idx]:(idx == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[idx + 1] - 1))
# TODO: Make logarithmic
nelems = 0
for i in range
tpdum.data[i] == 0 && break
nelems += 1
end
return TwoPhaseVectorView(tpdum.data, nelems, range)
end
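# A usage sketch (illustrative only): assuming a 3-statement IR in which statement 3 uses
# both %1 and %2, the two phases are separated by `complete!`:
#     tpdum = TwoPhaseDefUseMap(3)
#     count!(tpdum, SSAValue(1)); count!(tpdum, SSAValue(2))   # phase 1: count uses per def
#     complete!(tpdum)                                         # allocate the flat `data` buffer once
#     push!(tpdum[1], 3); push!(tpdum[2], 3)                   # phase 2: record the actual use sites
#     collect(tpdum[1])                                        # == [3]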
mutable struct LazyCFGReachability
ir::IRCode
reachability::CFGReachability
LazyCFGReachability(ir::IRCode) = new(ir)
end
function get!(x::LazyCFGReachability)
isdefined(x, :reachability) && return x.reachability
domtree = construct_domtree(x.ir)
return x.reachability = CFGReachability(x.ir.cfg, domtree)
end
mutable struct LazyGenericDomtree{IsPostDom}
ir::IRCode
domtree::GenericDomTree{IsPostDom}
LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = new{IsPostDom}(ir)
end
function get!(x::LazyGenericDomtree{IsPostDom}) where {IsPostDom}
isdefined(x, :domtree) && return x.domtree
return @timeit "domtree 2" x.domtree = IsPostDom ?
construct_postdomtree(x.ir) :
construct_domtree(x.ir)
end
const LazyDomtree = LazyGenericDomtree{false}
const LazyPostDomtree = LazyGenericDomtree{true}
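# Illustrative: both wrappers delay a relatively expensive CFG analysis until first use:
#     lazydomtree = LazyDomtree(ir)      # nothing computed yet
#     domtree = get!(lazydomtree)        # builds the dominator tree on the first call
#     get!(lazydomtree) === domtree      # later calls reuse the cached result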
# InferenceState
# ==============
"""
const VarTable = Vector{VarState}
The extended lattice that maps local variables to their inferred types, represented as elements of an `AbstractLattice`.
Each index corresponds to the `id` of `SlotNumber` which identifies each local variable.
Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statement
to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}
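# Illustrative: a `VarTable` is indexed by slot id, e.g. for some `s::SlotNumber`
#     vs = vartable[s.id]    # ::VarState for that local at the current program point
#     vs.typ, vs.undef       # inferred lattice element and whether the slot may be undefined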
const CACHE_MODE_NULL = 0x00 # not cached, without optimization
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed
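# Illustrative: these are independent bit flags, tested with `&`, e.g.
#     !iszero(cache_mode & CACHE_MODE_GLOBAL)   # cached globally? (cf. `is_cached` below)
#     !iszero(cache_mode & CACHE_MODE_LOCAL)    # pushed into the local inference cache?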
mutable struct TryCatchFrame
exct
scopet
const enter_idx::Int
scope_uses::Vector{Int}
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
end
mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
world::UInt
mod::Module
sptypes::Vector{VarState}
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG
method_info::MethodInfo
#= intermediate states for local abstract interpretation =#
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Union{Nothing,Vector{Any}}}
stmt_info::Vector{CallInfo}
#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
callers_in_cycle::Vector{InferenceState}
dont_work_on_me::Bool
parent # ::Union{Nothing,AbsIntState}
#= results =#
result::InferenceResult # remember where to put the result
unreachable::BitSet # statements that were found to be statically unreachable
valid_worlds::WorldRange
bestguess #::Type
exc_bestguess
ipo_effects::Effects
#= flags =#
# Whether to restrict inference of abstract call sites to avoid excessive work
# Set by default for toplevel frame.
restrict_abstract_call_sites::Bool
cache_mode::UInt8 # TODO move this to InferenceResult?
insert_coverage::Bool
# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
interp::AbstractInterpreter
# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult, src::CodeInfo, cache_mode::UInt8,
interp::AbstractInterpreter)
linfo = result.linfo
world = get_inference_world(interp)
if world == typemax(UInt)
error("Entering inference from a generated function with an invalid world")
end
def = linfo.def
mod = isa(def, Method) ? def.module : def
sptypes = sptypes_from_meth_instance(linfo)
code = src.code::Vector{Any}
cfg = compute_basic_blocks(code)
method_info = MethodInfo(src)
currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at, handlers = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
stmt_info = CallInfo[ NoCallInfo() for i = 1:nstmts ]
nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
argtypes = result.argtypes
nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
end
src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
unreachable = BitSet()
pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
cycle_backedges = Vector{Tuple{InferenceState,Int}}()
callers_in_cycle = Vector{InferenceState}()
dont_work_on_me = false
parent = nothing
valid_worlds = WorldRange(1, get_world_counter())
bestguess = Bottom
exc_bestguess = Bottom
ipo_effects = EFFECTS_TOTAL
insert_coverage = should_insert_coverage(mod, src.debuginfo)
if insert_coverage
ipo_effects = Effects(ipo_effects; effect_free = ALWAYS_FALSE)
end
if def isa Method
ipo_effects = Effects(ipo_effects; nonoverlayed=is_nonoverlayed(def))
end
restrict_abstract_call_sites = isa(def, Module)
# some more setup
InferenceParams(interp).unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
!iszero(cache_mode & CACHE_MODE_LOCAL) && push!(get_inference_cache(interp), result)
this = new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
# Apply generated function restrictions
if src.min_world != 1 || src.max_world != typemax(UInt)
# From generated functions
this.valid_worlds = WorldRange(src.min_world, src.max_world)
end
return this
end
end
is_nonoverlayed(m::Method) = !isdefined(m, :external_mt)
is_nonoverlayed(interp::AbstractInterpreter) = !isoverlayed(method_table(interp))
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)
is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing
was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND
function compute_trycatch(code::Vector{Any}, ip::BitSet)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 2: (expr) # == 1
# 3: (leave %1) # == 1
# 4: (expr) # == 0
# then we can find all `try`s by walking backwards from :enter statements,
# and all `catch`es by looking at the statement after the :enter
n = length(code)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill((0, 0), n)
handlers = TryCatchFrame[]
# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isa(stmt, EnterNode)
l = stmt.catch_dest
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
if l != 0
handler_at[l] = (0, handler_id)
push!(ip, l)
end
end
end
# now forward those marks to all :leave statements
while true
# make progress on the active ip set
pc = _bits_findnext(ip.bits, 0)::Int
pc > n && break
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
delete!(ip, pc)
cur_stacks = handler_at[pc]
@assert cur_stacks != (0, 0) "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_stacks
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
handler_at[l] = cur_stacks
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
break
elseif isa(stmt, EnterNode)
l = stmt.catch_dest
# We assigned a handler number above. Here we just merge that
# with our current handler information.
if l != 0
handler_at[l] = (cur_stacks[1], handler_at[l][2])
end
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif isa(stmt, Expr)
head = stmt.head
if head === :leave
l = 0
for j = 1:length(stmt.args)
arg = stmt.args[j]
if arg === nothing
continue
else
enter_stmt = code[(arg::SSAValue).id]
if enter_stmt === nothing
continue
end
@assert isa(enter_stmt, EnterNode) "malformed :leave"
end
l += 1
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
end
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
elseif head === :pop_exception
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
cur_stacks == (0, 0) && break
end
end
pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_stacks
handler_at[pc´] = cur_stacks
elseif !in(pc´, ip)
break # already visited
end
pc = pc´
end
end
@assert first(ip) == n + 1
return handler_at, handlers
end
# check if coverage mode is enabled
function should_insert_coverage(mod::Module, debuginfo::DebugInfo)
coverage_enabled(mod) && return true
JLOptions().code_coverage == 3 || return false
# path-specific coverage mode: if any line falls in a tracked file enable coverage for all
return _should_insert_coverage(debuginfo)
end
_should_insert_coverage(mod::Symbol) = is_file_tracked(mod)
_should_insert_coverage(mod::Method) = _should_insert_coverage(mod.file)
_should_insert_coverage(mod::MethodInstance) = _should_insert_coverage(mod.def)
_should_insert_coverage(mod::Module) = false
function _should_insert_coverage(info::DebugInfo)
linetable = info.linetable
linetable === nothing || (_should_insert_coverage(linetable) && return true)
_should_insert_coverage(info.def) && return true
return false
end
function InferenceState(result::InferenceResult, cache_mode::UInt8, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
world = get_inference_world(interp)
src = retrieve_code_info(result.linfo, world)
src === nothing && return nothing
maybe_validate_code(result.linfo, src, "lowered")
return InferenceState(result, src, cache_mode, interp)
end
InferenceState(result::InferenceResult, cache_mode::Symbol, interp::AbstractInterpreter) =
InferenceState(result, convert_cache_mode(cache_mode), interp)
InferenceState(result::InferenceResult, src::CodeInfo, cache_mode::Symbol, interp::AbstractInterpreter) =
InferenceState(result, src, convert_cache_mode(cache_mode), interp)
function convert_cache_mode(cache_mode::Symbol)
if cache_mode === :global
return CACHE_MODE_GLOBAL
elseif cache_mode === :local
return CACHE_MODE_LOCAL
elseif cache_mode === :volatile
return CACHE_MODE_VOLATILE
elseif cache_mode === :no
return CACHE_MODE_NULL
end
error("unexpected `cache_mode` is given")
end
"""
constrains_param(var::TypeVar, sig, covariant::Bool, type_constrains::Bool)
Check if `var` will be constrained to have a definite value
in any concrete leaftype subtype of `sig`.
It is used as a helper to determine whether type intersection is guaranteed to be able to
find a value for a particular type parameter.
A necessary condition for type intersection to not assign a parameter is that it only
appears in a `Union`/`UnionAll` and during subtyping some other union component (one that does
not constrain the type parameter) is selected.
The `type_constrains` flag determines whether `Type{T}` is considered to be constraining
`T`. This is not true in general because of the existence of types with free type
parameters; however, some callers would like to ignore this corner case.
"""
function constrains_param(var::TypeVar, @nospecialize(typ), covariant::Bool, type_constrains::Bool=false)
typ === var && return true
while typ isa UnionAll
covariant && constrains_param(var, typ.var.ub, covariant, type_constrains) && return true
# typ.var.lb doesn't constrain var
typ = typ.body
end
if typ isa Union
# for unions, verify that both options would constrain var
ba = constrains_param(var, typ.a, covariant, type_constrains)
bb = constrains_param(var, typ.b, covariant, type_constrains)
(ba && bb) && return true
elseif typ isa DataType
# return true if any param constrains var
fc = length(typ.parameters)
if fc > 0
if typ.name === Tuple.name
# vararg tuple needs special handling
for i in 1:(fc - 1)
p = typ.parameters[i]
constrains_param(var, p, covariant, type_constrains) && return true
end
lastp = typ.parameters[fc]
vararg = unwrap_unionall(lastp)
if vararg isa Core.TypeofVararg && isdefined(vararg, :N)
constrains_param(var, vararg.N, covariant, type_constrains) && return true
# T = vararg.parameters[1] doesn't constrain var
else
constrains_param(var, lastp, covariant, type_constrains) && return true
end
else
if typ.name === typename(Type) && typ.parameters[1] === var && var.ub === Any
# Types with free type parameters are <: Type and cause the typevar
# to be unconstrained, because Type{T} with free typevars is illegal
return type_constrains
end
for i in 1:fc
p = typ.parameters[i]
constrains_param(var, p, false, type_constrains) && return true
end
end
end
end
return false
end
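# Illustrative examples of the above (covariant position; `T` here is a plain `TypeVar`):
#     T = TypeVar(:T)
#     constrains_param(T, Tuple{T}, true)                    # true: a concrete leaftype subtype determines T
#     constrains_param(T, Union{Tuple{T},Tuple{Int}}, true)  # false: the Tuple{Int} component leaves T free
#     constrains_param(T, Type{T}, true)                     # false by default; true with type_constrains=true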
const EMPTY_SPTYPES = VarState[]
function sptypes_from_meth_instance(linfo::MethodInstance)
def = linfo.def
isa(def, Method) || return EMPTY_SPTYPES # toplevel
sig = def.sig
if isempty(linfo.sparam_vals)
isa(sig, UnionAll) || return EMPTY_SPTYPES
# linfo is unspecialized
spvals = Any[]
sig′ = sig
while isa(sig′, UnionAll)
push!(spvals, sig′.var)
sig′ = sig′.body
end
else
spvals = linfo.sparam_vals
end
nvals = length(spvals)
sptypes = Vector{VarState}(undef, nvals)
for i = 1:nvals
v = spvals[i]
if v isa TypeVar
temp = sig
for j = 1:i-1
temp = temp.body
end
vᵢ = (temp::UnionAll).var
sigtypes = (unwrap_unionall(temp)::DataType).parameters
for j = 1:length(sigtypes)
sⱼ = sigtypes[j]
if isType(sⱼ) && sⱼ.parameters[1] === vᵢ
# if this parameter came from `arg::Type{T}`,
# then `arg` is more precise than `Type{T} where lb<:T<:ub`
ty = fieldtype(linfo.specTypes, j)
@goto ty_computed
elseif (va = va_from_vatuple(sⱼ)) !== nothing
# if this parameter came from `::Tuple{.., Vararg{T,vᵢ}}`,
# then `vᵢ` is known to be `Int`
if isdefined(va, :N) && va.N === vᵢ
ty = Int
@goto ty_computed
end
end
end
ub = unwraptv_ub(v)
if has_free_typevars(ub)
ub = Any
end
lb = unwraptv_lb(v)
if has_free_typevars(lb)
lb = Bottom
end
if Any === ub && lb === Bottom
ty = Any
else
tv = TypeVar(v.name, lb, ub)
ty = UnionAll(tv, Type{tv})
end
@label ty_computed
undef = !(let sig=sig
# if the specialized signature `linfo.specTypes` doesn't contain any free
# type variables, we can use it for a more accurate analysis of whether `v`
# is constrained or not, otherwise we should use `def.sig` which always
# doesn't contain any free type variables
if !has_free_typevars(linfo.specTypes)
sig = linfo.specTypes
end
@assert !has_free_typevars(sig)
constrains_param(v, sig, #=covariant=#true)
end)
elseif isvarargtype(v)
# if this parameter came from `func(..., ::Vararg{T,v})`,
# then the type is known to be `Int`
ty = Int
undef = false
else
ty = Const(v)
undef = false
end
sptypes[i] = VarState(ty, undef)
end
return sptypes
end
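# Illustrative: for `f(x::T) where {T}` specialized as `f(1)`, `linfo.sparam_vals` holds the
# concrete `Int64`, so the entry is `VarState(Const(Int64), false)`. A static parameter that
# cannot be recovered by intersection (e.g. one that appears nowhere in the signature) keeps a
# widened type (`Any`, or the `UnionAll(tv, Type{tv})` form for non-trivial bounds) with `undef=true`.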
function va_from_vatuple(@nospecialize(t))
@_foldable_meta
t = unwrap_unionall(t)
if isa(t, DataType)
n = length(t.parameters)
if n > 0
va = t.parameters[n]
if isvarargtype(va)
return va
end
end
end
return nothing
end
_topmod(sv::InferenceState) = _topmod(frame_module(sv))
function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
if old === NOT_FOUND || !is_lattice_equal(𝕃ᵢ, new, old)
ssavaluetypes[ssa_id] = new
W = frame.ip
for r in frame.ssavalue_uses[ssa_id]
if was_reached(frame, r)
usebb = block_for_inst(frame.cfg, r)
# We're guaranteed to visit the statement if it's in the current
# basic block, since SSA values can only ever appear after their
# def.
if usebb != frame.currbb
push!(W, usebb)
end
end
end
end
return nothing
end
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(caller, frame.valid_worlds)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(caller, frame.linfo)
return frame
end
function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc)
stmt_edges = caller.stmt_edges
edges = stmt_edges[currpc]
if edges === nothing
edges = stmt_edges[currpc] = []
end
return edges
end
function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
edges = frame.stmt_edges[currpc]
edges === nothing || empty!(edges)
return nothing
end
function print_callstack(sv::InferenceState)
print("=================== Callstack: ==================\n")
idx = 0
while sv !== nothing
print("[")
print(idx)
if !isa(sv.interp, NativeInterpreter)
print(", ")
print(typeof(sv.interp))
end
print("] ")
print(sv.linfo)
is_cached(sv) || print(" [uncached]")
println()
for cycle in sv.callers_in_cycle
print(' ', cycle.linfo)
println()
end
sv = sv.parent
idx += 1
end
print("================= End callstack ==================\n")
end
function narguments(sv::InferenceState, include_va::Bool=true)
def = sv.linfo.def
nargs = length(sv.result.argtypes)
if !include_va
nargs -= isa(def, Method) && def.isva
end
return nargs
end
# IRInterpretationState
# =====================
# TODO add `result::InferenceResult` and put the irinterp result into the inference cache?
mutable struct IRInterpretationState
const method_info::MethodInfo
const ir::IRCode
const mi::MethodInstance
const world::UInt
curridx::Int
const argtypes_refined::Vector{Bool}
const sptypes::Vector{VarState}
const tpdum::TwoPhaseDefUseMap
const ssa_refined::BitSet
const lazyreachability::LazyCFGReachability
valid_worlds::WorldRange
const edges::Vector{Any}
parent # ::Union{Nothing,AbsIntState}
function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
world::UInt, min_world::UInt, max_world::UInt)
curridx = 1
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(given_argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(optimizer_lattice(interp), given_argtypes, mi)
argtypes_refined = Bool[!⊑(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
for i = 1:length(given_argtypes)]
empty!(ir.argtypes)
append!(ir.argtypes, given_argtypes)
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
ssa_refined = BitSet()
lazyreachability = LazyCFGReachability(ir)
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
edges = Any[]
parent = nothing
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, valid_worlds, edges, parent)
end
end
function IRInterpretationState(interp::AbstractInterpreter,
code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt)
@assert code.def === mi
src = @atomic :monotonic code.inferred
if isa(src, String)
src = _uncompressed_ir(code, src)
else
isa(src, CodeInfo) || return nothing
end
method_info = MethodInfo(src)
ir = inflate_ir(src, mi)
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
code.min_world, code.max_world)
end
# AbsIntState
# ===========
const AbsIntState = Union{InferenceState,IRInterpretationState}
frame_instance(sv::InferenceState) = sv.linfo
frame_instance(sv::IRInterpretationState) = sv.mi
function frame_module(sv::AbsIntState)
mi = frame_instance(sv)
def = mi.def
isa(def, Module) && return def
return def.module
end
frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState}
frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}
is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const)
is_constproped(::IRInterpretationState) = true
is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)
is_cached(::IRInterpretationState) = false
method_info(sv::InferenceState) = sv.method_info
method_info(sv::IRInterpretationState) = sv.method_info
propagate_inbounds(sv::AbsIntState) = method_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_for_inference_limit_heuristics
frame_world(sv::InferenceState) = sv.world
frame_world(sv::IRInterpretationState) = sv.world
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
callers_in_cycle(sv::IRInterpretationState) = ()
function is_effect_overridden(sv::AbsIntState, effect::Symbol)
if is_effect_overridden(frame_instance(sv), effect)
return true
elseif is_effect_overridden(decode_statement_effects_override(sv), effect)
return true
end
return false
end
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
def = linfo.def
return isa(def, Method) && is_effect_overridden(def, effect)
end
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)
has_conditional(𝕃::AbstractLattice, ::InferenceState) = has_conditional(𝕃)
has_conditional(::AbstractLattice, ::IRInterpretationState) = false
# work towards converging the valid age range for sv
function update_valid_age!(sv::AbsIntState, valid_worlds::WorldRange)
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
@assert sv.world in valid_worlds "invalid age range update"
return valid_worlds
end
"""
AbsIntStackUnwind(sv::AbsIntState)
Iterate through all callers of the given `AbsIntState` in the abstract interpretation stack
(including the given `AbsIntState` itself), visiting children before their parents (i.e.
ascending the tree from the given `AbsIntState`).
Note that cycles may be visited in any order.
"""
struct AbsIntStackUnwind
sv::AbsIntState
end
iterate(unw::AbsIntStackUnwind) = (unw.sv, (unw.sv, 0))
function iterate(unw::AbsIntStackUnwind, (sv, cyclei)::Tuple{AbsIntState, Int})
# iterate through the cycle before walking to the parent
callers = callers_in_cycle(sv)
if callers !== () && cyclei < length(callers)
cyclei += 1
parent = callers[cyclei]
else
cyclei = 0
parent = frame_parent(sv)
end
parent === nothing && return nothing
return (parent, (parent, cyclei))
end
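# Illustrative: the iterator lets callers walk every frame from `sv` up to the root, e.g.
#     for frame in AbsIntStackUnwind(sv)
#         # `frame` is first `sv` itself, then its cycle peers, then its parent, and so on
#     end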
# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mi)
end
function add_backedge!(irsv::IRInterpretationState, mi::MethodInstance)
return push!(irsv.edges, mi)
end
function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), invokesig, mi)
end
function add_invoke_backedge!(irsv::IRInterpretationState, @nospecialize(invokesig::Type), mi::MethodInstance)
return push!(irsv.edges, invokesig, mi)
end
# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mt, typ)
end
function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospecialize(typ))
return push!(irsv.edges, mt, typ)
end
get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]
has_curr_ssaflag(sv::InferenceState, flag::UInt32) = has_flag(sv.src.ssaflags[sv.currpc], flag)
has_curr_ssaflag(sv::IRInterpretationState, flag::UInt32) = has_flag(sv.ir.stmts[sv.curridx][:flag], flag)
function set_curr_ssaflag!(sv::InferenceState, flag::UInt32, mask::UInt32=typemax(UInt32))
curr_flag = sv.src.ssaflags[sv.currpc]
sv.src.ssaflags[sv.currpc] = (curr_flag & ~mask) | flag
end
function set_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32, mask::UInt32=typemax(UInt32))
curr_flag = sv.ir.stmts[sv.curridx][:flag]
sv.ir.stmts[sv.curridx][:flag] = (curr_flag & ~mask) | flag
end
add_curr_ssaflag!(sv::InferenceState, flag::UInt32) = sv.src.ssaflags[sv.currpc] |= flag
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32) = add_flag!(sv.ir.stmts[sv.curridx], flag)
sub_curr_ssaflag!(sv::InferenceState, flag::UInt32) = sv.src.ssaflags[sv.currpc] &= ~flag
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32) = sub_flag!(sv.ir.stmts[sv.curridx], flag)
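# Illustrative: abstract interpretation uses these helpers to read and update per-statement
# flags, e.g. (flag constants such as IR_FLAG_NOTHROW/IR_FLAG_EFFECT_FREE are defined elsewhere
# in the compiler):
#     has_curr_ssaflag(sv, IR_FLAG_NOTHROW)
#     add_curr_ssaflag!(sv, IR_FLAG_EFFECT_FREE)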
function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
if effects.effect_free === EFFECT_FREE_GLOBALLY
# This tracks the global effects
effects = Effects(effects; effect_free=ALWAYS_TRUE)
end
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end
merge_effects!(::AbstractInterpreter, ::IRInterpretationState, ::Effects) = return
decode_statement_effects_override(sv::AbsIntState) =
decode_statement_effects_override(get_curr_ssaflag(sv))
struct InferenceLoopState
sig
rt
effects::Effects
function InferenceLoopState(@nospecialize(sig), @nospecialize(rt), effects::Effects)
new(sig, rt, effects)
end
end
bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState) =
sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
bail_out_toplevel_call(::AbstractInterpreter, ::InferenceLoopState, ::IRInterpretationState) = false
bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState) =
state.rt === Any && !is_foldable(state.effects)
bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::IRInterpretationState) =
state.rt === Any && !is_foldable(state.effects)
bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState) =
state.rt === Any
bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::IRInterpretationState) =
state.rt === Any
function should_infer_this_call(interp::AbstractInterpreter, sv::InferenceState)
if InferenceParams(interp).unoptimize_throw_blocks
# Disable inference of calls in throw blocks, since we're unlikely to
# need their types. There is one exception however: If up until now, the
# function has not seen any side effects, we would like to make sure there
# aren't any in the throw block either to enable other optimizations.
if is_stmt_throw_block(get_curr_ssaflag(sv))
should_infer_for_effects(sv) || return false
end
end
return true
end
function should_infer_for_effects(sv::InferenceState)
def = sv.linfo.def
def isa Method || return false # toplevel frame will not be [semi-]concrete-evaluated
effects = sv.ipo_effects
override = decode_effects_override(def.purity)
effects.consistent === ALWAYS_FALSE && !is_effect_overridden(override, :consistent) && return false
effects.effect_free === ALWAYS_FALSE && !is_effect_overridden(override, :effect_free) && return false
!effects.terminates && !is_effect_overridden(override, :terminates_globally) && return false
return true
end
should_infer_this_call(::AbstractInterpreter, ::IRInterpretationState) = true
add_remark!(::AbstractInterpreter, ::InferenceState, remark) = return
add_remark!(::AbstractInterpreter, ::IRInterpretationState, remark) = return
function get_max_methods(interp::AbstractInterpreter, @nospecialize(f), sv::AbsIntState)
fmax = get_max_methods_for_func(f)
fmax !== nothing && return fmax
return get_max_methods(interp, sv)
end
function get_max_methods(interp::AbstractInterpreter, @nospecialize(f))
fmax = get_max_methods_for_func(f)
fmax !== nothing && return fmax
return get_max_methods(interp)
end
function get_max_methods(interp::AbstractInterpreter, sv::AbsIntState)
mmax = get_max_methods_for_module(sv)
mmax !== nothing && return mmax
return get_max_methods(interp)
end
get_max_methods(interp::AbstractInterpreter) = InferenceParams(interp).max_methods
function get_max_methods_for_func(@nospecialize(f))
if f !== nothing
fmm = typeof(f).name.max_methods
fmm !== UInt8(0) && return Int(fmm)
end
return nothing
end
get_max_methods_for_module(sv::AbsIntState) = get_max_methods_for_module(frame_module(sv))
function get_max_methods_for_module(mod::Module)
max_methods = ccall(:jl_get_module_max_methods, Cint, (Any,), mod) % Int
max_methods < 0 && return nothing
return max_methods
end
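# Illustrative: the per-function and per-module limits consulted above are typically set with
#     Base.Experimental.@max_methods 1 function f end   # stored in `typeof(f).name.max_methods`
#     Base.Experimental.@max_methods 2                  # module-wide, read via `jl_get_module_max_methods`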