# https://github.com/JuliaLang/julia
# Raw File
# Tip revision: 1e8a1fad151a3bdfe1bee3a361e0068a4d406a1d authored by Milan Bouchet-Valat on 28 May 2022, 20:15:24 UTC
# Restore fast path for `Dict(d::Dict{K,V})` constructor
# Tip revision: 1e8a1fa
# optimize.jl
# This file is a part of Julia. License is MIT: https://julialang.org/license

#############
# constants #
#############

# Bit flags describing how a slot is used (checked against `CodeInfo.slotflags`).
#
# SLOT_STATICUNDEF: the slot has uses that are not statically dominated by any
# assignment. This is implied by `SLOT_USEDUNDEF`. If it is not set, all the uses are
# (statically) dominated by the defs. In particular, if a slot has
# `AssignedOnce && !StaticUndef`, it is an SSA.
const SLOT_STATICUNDEF  = 1 << 0 # slot might be used before it is defined (structurally)
const SLOT_ASSIGNEDONCE = 1 << 4 # slot is assigned to only once
const SLOT_USEDUNDEF    = 1 << 5 # slot has uses that might raise UndefVarError
# const SLOT_CALLED      = 1 << 6

# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c

# Per-statement flag bits (`UInt8` values), queried with the `is_stmt_*` predicates.
const IR_FLAG_NULL        = 0x00
# This statement is marked as @inbounds by user.
# If replaced by inlining, any contained boundschecks may be removed.
const IR_FLAG_INBOUNDS    = 0x01 << 0
# This statement is marked as @inline by user
const IR_FLAG_INLINE      = 0x01 << 1
# This statement is marked as @noinline by user
const IR_FLAG_NOINLINE    = 0x01 << 2
# This statement lies in a throw block (an error path). The inlining cost model checks it
# via `is_stmt_throw_block`. It may then charge `params.inline_error_path_cost` instead of
# the usual non-leaf penalty.
const IR_FLAG_THROW_BLOCK = 0x01 << 3
# This statement may be removed if its result is unused. In particular it must
# thus be both pure and effect free.
const IR_FLAG_EFFECT_FREE = 0x01 << 4

# `GlobalRef` naming the builtin `Core.tuple`.
const TOP_TUPLE = GlobalRef(Core, :tuple)

#####################
# OptimizationState #
#####################

# Collects the edges recorded while optimizing a function, together with the range of
# worlds in which the optimized code remains valid. `valid_worlds` is held in a
# `RefValue` so the (immutable) tracker can narrow it in place (see `intersect!`).
struct EdgeTracker
    edges::Vector{Any}
    valid_worlds::RefValue{WorldRange}
    EdgeTracker(edges::Vector{Any}, range::WorldRange) =
        new(edges, RefValue{WorldRange}(range))
end
# Start with no edges and the widest possible world range. The inner constructor is
# typed `range::WorldRange`, so build a `WorldRange` explicitly. A bare
# `0:typemax(UInt)` range would not dispatch to it.
EdgeTracker() = EdgeTracker(Any[], WorldRange(UInt(0), typemax(UInt)))

# Narrow the tracker's valid world range to its intersection with `range` and return
# the narrowed range.
function intersect!(et::EdgeTracker, range::WorldRange)
    narrowed = intersect(et.valid_worlds[], range)
    et.valid_worlds[] = narrowed
    return narrowed
end

# Record `mi` as an edge of the code being optimized.
push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
# Record the `MethodInstance` behind `ci` as an edge. Also narrow the tracker's valid
# world range to the worlds in which `ci` itself is valid.
function push!(et::EdgeTracker, ci::CodeInstance)
    # the world bounds come from `ci` itself (there is no other code instance in scope)
    intersect!(et, WorldRange(ci.min_world, ci.max_world))
    return push!(et, ci.def)
end

# Configuration for the inlining pass: optimization parameters, an optional edge tracker,
# a code cache (a `WorldView` over `code_cache(interp)` in the `OptimizationState`
# constructors), and the interpreter.
struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInterpreter}
    params::OptimizationParams
    # `nothing` when edges are not tracked (e.g. the unit-testing `OptimizationState` constructor)
    et::S
    mi_cache::MICache # TODO move this to `OptimizationState` (as used by EscapeAnalysis as well)
    interp::I
end

# Whether `src` (a `CodeInfo`, or its serialized `Vector{UInt8}` form) carries the
# "inferred" flag, as reported by the runtime (`jl_ir_flag_inferred`).
function is_source_inferred(@nospecialize(src::Union{CodeInfo, Vector{UInt8}}))
    return ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
end

# Decide which source (if any) the inliner may use for a call to `mi`.
# - An inferred `CodeInfo` (or serialized `Vector{UInt8}`) is usable when the call site is
#   marked `@inline` or the source itself is flagged inlineable.
# - With no source at all but an `@inline` call site, make an additional effort to find
#   the inferred source in the local inference cache. This still won't find a source for
#   a recursive call, because "single-level" inlining seems to be more trouble and
#   complexity than it's worth.
# Returns the usable source, or `nothing`.
function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8,
                         mi::MethodInstance, argtypes::Vector{Any})
    if isa(src, CodeInfo) || isa(src, Vector{UInt8})
        is_source_inferred(src) || return nothing
        if is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
            return src
        end
        return nothing
    end
    # only a forced (`@inline`) call without a source gets the extra cache lookup
    (src === nothing && is_stmt_inline(stmt_flag)) || return nothing
    inf_result = cache_lookup(mi, argtypes, get_inference_cache(interp))
    inf_result === nothing && return nothing
    cached_src = inf_result.src
    if isa(cached_src, CodeInfo) && is_source_inferred(cached_src)
        return cached_src
    end
    return nothing
end

include("compiler/ssair/driver.jl")

# State carried through the optimization passes for a single method instance. It holds:
# - the source being optimized;
# - its IR (`nothing` when constructed; `finish` stores it and `ir_to_codeinf!` clears it);
# - per-statement info;
# - static-parameter and slot types;
# - the inlining configuration.
mutable struct OptimizationState
    linfo::MethodInstance
    src::CodeInfo
    ir::Union{Nothing, IRCode}
    # one entry per statement (kept in step with the code by `convert_to_ircode`)
    stmt_info::Vector{Any}
    # module of the code being optimized (used by `_topmod` and coverage checks)
    mod::Module
    sptypes::Vector{Any} # static parameters
    slottypes::Vector{Any}
    inlining::InliningState
    # Construct from an inference frame. This reuses the frame's source and type
    # information, records edges into the frame's first statement-edge vector, and views
    # the code cache at the frame's world.
    function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
        s_edges = frame.stmt_edges[1]::Vector{Any}
        inlining = InliningState(params,
            EdgeTracker(s_edges, frame.valid_worlds),
            WorldView(code_cache(interp), frame.world),
            interp)
        return new(frame.linfo,
                   frame.src, nothing, frame.stmt_info, frame.mod,
                   frame.sptypes, frame.slottypes, inlining)
    end
    # Construct directly from `src`, without tracking edges. Mutates `src.ssavaluetypes`
    # when it still holds only the count of SSA values.
    function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
        # prepare src for running optimization passes
        # if it isn't already
        nssavalues = src.ssavaluetypes
        if nssavalues isa Int
            # only the number of SSA values is known: type each of them as `Any`
            src.ssavaluetypes = Any[ Any for i = 1:nssavalues ]
        else
            nssavalues = length(src.ssavaluetypes::Vector{Any})
        end
        nslots = length(src.slotflags)
        slottypes = src.slottypes
        if slottypes === nothing
            slottypes = Any[ Any for i = 1:nslots ]
        end
        stmt_info = Any[nothing for i = 1:nssavalues]
        # cache some useful state computations
        def = linfo.def
        mod = isa(def, Method) ? def.module : def
        # Allow using the global MI cache, but don't track edges.
        # This method is mostly used for unit testing the optimizer
        inlining = InliningState(params,
            nothing,
            WorldView(code_cache(interp), get_world_counter()),
            interp)
        return new(linfo,
                   src, nothing, stmt_info, mod,
                   sptypes_from_meth_instance(linfo), slottypes, inlining)
    end
end

# Convenience constructor: retrieve the source for `linfo` and build an
# `OptimizationState` from it. Returns `nothing` when no source is available.
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
    code = retrieve_code_info(linfo)
    return code === nothing ? nothing : OptimizationState(linfo, code, params, interp)
end

# Write the optimized `opt.ir` back into `opt.src` (in place) and drop the IR reference.
# Then widen constant annotations, mark the source as inferred, and run
# `validate_code_in_debug_mode` on the result. Returns the updated `CodeInfo`.
function ir_to_codeinf!(opt::OptimizationState)
    mi = opt.linfo
    ci = opt.src
    mdef = mi.def
    nargs = isa(mdef, Method) ? Int(mdef.nargs) : 0
    replace_code_newstyle!(ci, opt.ir::IRCode, nargs)
    opt.ir = nothing
    widen_all_consts!(ci)
    ci.inferred = true
    # finish updating the result struct
    validate_code_in_debug_mode(mi, ci, "optimized")
    return ci
end

#########
# logic #
#########

# Forward to `_topmod` of the module being optimized.
function _topmod(sv::OptimizationState)
    return _topmod(sv.mod)
end

# Tests on the per-statement `IR_FLAG_*` bits.
is_stmt_inline(stmt_flag::UInt8)      = (stmt_flag & IR_FLAG_INLINE) != 0
is_stmt_noinline(stmt_flag::UInt8)    = (stmt_flag & IR_FLAG_NOINLINE) != 0
is_stmt_throw_block(stmt_flag::UInt8) = (stmt_flag & IR_FLAG_THROW_BLOCK) != 0

# Whether `stmt` must be taken into account when proving the whole function pure.
# Jumps, returns, conditional branches on a `Bool` condition, and `:loopinfo`/`:enter`
# expressions affect control flow within the function (so they may not be removed even
# when unused), but don't affect the purity of the function as a whole.
function stmt_affects_purity(@nospecialize(stmt), ir)
    (stmt isa GotoNode || stmt isa ReturnNode) && return false
    if stmt isa GotoIfNot
        # branching on a condition that is not known to be a `Bool` may throw
        return !(argextype(stmt.cond, ir) ⊑ Bool)
    end
    if stmt isa Expr
        head = stmt.head
        return !(head === :loopinfo || head === :enter)
    end
    return true
end

"""
    stmt_effect_free(stmt, rt, src::Union{IRCode,IncrementalCompact})

Determine whether a `stmt` is "side-effect-free", i.e. may be removed if it has no uses.

`rt` is the statement's inferred result type. A builtin call whose `rt` is `Bottom`
(it always throws) is never considered effect-free. `Expr` heads that are not
explicitly recognized are treated as not effect-free.
"""
function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
    isa(stmt, PiNode) && return true
    isa(stmt, PhiNode) && return true
    isa(stmt, ReturnNode) && return false
    isa(stmt, GotoNode) && return false
    isa(stmt, GotoIfNot) && return false
    isa(stmt, Slot) && return false # Slots shouldn't occur in the IR at this point, but let's be defensive here
    isa(stmt, GlobalRef) && return isdefined(stmt.mod, stmt.name)
    if isa(stmt, Expr)
        (; head, args) = stmt
        if head === :static_parameter
            etyp = (isa(src, IRCode) ? src.sptypes : src.ir.sptypes)[args[1]::Int]
            # if we aren't certain enough about the type, it might be an UndefVarError at runtime
            return isa(etyp, Const)
        end
        if head === :call
            f = argextype(args[1], src)
            f = singleton_type(f)
            f === nothing && return false
            if isa(f, IntrinsicFunction)
                intrinsic_effect_free_if_nothrow(f) || return false
                return intrinsic_nothrow(f,
                        Any[argextype(args[i], src) for i = 2:length(args)])
            end
            contains_is(_PURE_BUILTINS, f) && return true
            # `get_binding_type` sets the type to Any if the binding doesn't exist yet
            if f === Core.get_binding_type
                length(args) == 3 || return false
                M, s = argextype(args[2], src), argextype(args[3], src)
                return get_binding_type_effect_free(M, s)
            end
            contains_is(_EFFECT_FREE_BUILTINS, f) || return false
            # a call inferred to always throw can never be removed
            rt === Bottom && return false
            return _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt)
        elseif head === :new
            typ = argextype(args[1], src)
            # `Expr(:new)` of unknown type could raise arbitrary TypeError.
            typ, isexact = instanceof_tfunc(typ)
            isexact || return false
            isconcretedispatch(typ) || return false
            typ = typ::DataType
            fieldcount(typ) >= length(args) - 1 || return false
            # every provided field value must already match the declared field type
            for fld_idx in 1:(length(args) - 1)
                eT = argextype(args[fld_idx + 1], src)
                fT = fieldtype(typ, fld_idx)
                eT ⊑ fT || return false
            end
            return true
        elseif head === :foreigncall
            return foreigncall_effect_free(stmt, src)
        elseif head === :new_opaque_closure
            length(args) < 4 && return false
            typ = argextype(args[1], src)
            typ, isexact = instanceof_tfunc(typ)
            isexact || return false
            typ ⊑ Tuple || return false
            rt_lb = argextype(args[2], src)
            rt_ub = argextype(args[3], src)
            # Use a fresh local rather than rebinding the argument `src`. `src` is
            # captured by the comprehensions above, and reassigning a captured variable
            # forces it into a `Core.Box`, making it type-unstable everywhere.
            mtyp = argextype(args[4], src)
            if !(rt_lb ⊑ Type && rt_ub ⊑ Type && mtyp ⊑ Method)
                return false
            end
            return true
        elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck
            return true
        else
            # e.g. :loopinfo
            return false
        end
    end
    return true
end

# Whether a `:foreigncall` statement may be removed when unused. Only the array allocation
# entry points recognized by `alloc_array_ndims` qualify, and only when the call can be
# shown not to throw.
function foreigncall_effect_free(stmt::Expr, src::Union{IRCode,IncrementalCompact})
    args = stmt.args
    callee = args[1]
    if callee isa QuoteNode
        callee = callee.value
    end
    callee isa Symbol || return false
    nd = alloc_array_ndims(callee)
    nd === nothing && return false
    # `jl_new_array` (reported as 0 dims) receives its dimensions as a single tuple argument
    nd == 0 && return new_array_no_throw(args, src)
    return alloc_array_no_throw(args, nd, src)
end

# Map a runtime array-allocation entry point to its number of explicit dimension
# arguments. `jl_new_array` maps to 0 because it takes a dims tuple instead. Any other
# name yields `nothing`.
function alloc_array_ndims(name::Symbol)
    name === :jl_alloc_array_1d && return 1
    name === :jl_alloc_array_2d && return 2
    name === :jl_alloc_array_3d && return 3
    name === :jl_new_array && return 0
    return nothing
end

# Index into a `:foreigncall` Expr's `args` of the first argument inspected by the array
# allocation checks: this slot holds the array type, and the dimension argument(s) follow.
# NOTE(review): presumably slots 1-5 hold the callee name, return type, argument types and
# calling-convention data — confirm against the foreigncall IR layout.
const FOREIGNCALL_ARG_START = 6

# Whether a `jl_alloc_array_{1,2,3}d` call with `args` is known not to throw. Each of
# the `ndims` dimension arguments must be a constant `Int`, and the resulting shape must
# pass `_new_array_no_throw` for the array type in the first argument slot.
function alloc_array_no_throw(args::Vector{Any}, ndims::Int, src::Union{IRCode,IncrementalCompact})
    first_arg = FOREIGNCALL_ARG_START
    length(args) < ndims + first_arg && return false
    atype = instanceof_tfunc(argextype(args[first_arg], src))[1]
    dims = Csize_t[]
    for k = (first_arg + 1):(first_arg + ndims)
        dimtyp = argextype(args[k], src)
        dimtyp isa Const || return false
        v = dimtyp.val
        v isa Int || return false
        # NOTE(review): a negative size reinterprets to a huge unsigned value, presumably
        # rejected by the runtime validation — confirm
        push!(dims, reinterpret(Csize_t, v))
    end
    return _new_array_no_throw(atype, ndims, dims)
end

# Whether a `jl_new_array(atype, dims)` call with `args` is known not to throw. `dims`
# must be a constant tuple of `Int`s, or the empty tuple type, and the array type and
# shape must pass `_new_array_no_throw`.
function new_array_no_throw(args::Vector{Any}, src::Union{IRCode,IncrementalCompact})
    length(args) ≥ FOREIGNCALL_ARG_START+1 || return false
    atype = instanceof_tfunc(argextype(args[FOREIGNCALL_ARG_START], src))[1]
    dims = argextype(args[FOREIGNCALL_ARG_START+1], src)
    if !isa(dims, Const)
        dims === Tuple{} || return false
        # A non-constant empty dims tuple still allocates a 0-dimensional array. Validate
        # the array type exactly as the constant `()` path below does, instead of
        # assuming success.
        return _new_array_no_throw(atype, 0, Csize_t[])
    end
    dimsval = dims.val
    isa(dimsval, Tuple{Vararg{Int}}) || return false
    ndims = nfields(dimsval)
    isa(ndims, Int) || return false
    dimvec = Csize_t[reinterpret(Csize_t, dimval) for dimval in dimsval]
    return _new_array_no_throw(atype, ndims, dimvec)
end

# Ask the runtime (`jl_array_validate_dims`) whether an array of type `atype` with the
# given `ndims` dimensions `dims` can be allocated. `atype` must be a `DataType` whose
# first parameter (the element type) is itself a type.
function _new_array_no_throw(@nospecialize(atype), ndims::Int, dims::Vector{Csize_t})
    atype isa DataType || return false
    elty = atype.parameters[1]
    iskindtype(typeof(elty)) || return false
    elsz = aligned_sizeof(elty)
    nel = RefValue{Csize_t}()
    tot = RefValue{Csize_t}()
    status = ccall(:jl_array_validate_dims, Cint,
                   (Ptr{Csize_t}, Ptr{Csize_t}, UInt32, Ptr{Csize_t}, Csize_t),
                   nel, tot, ndims, dims, elsz)
    return status == 0
end

"""
    argextype(x, src::Union{IRCode,IncrementalCompact}) -> t
    argextype(x, src::CodeInfo, sptypes::Vector{Any}) -> t

Return the type of value `x` in the context of inferred source `src`.
Note that `t` might be an extended lattice element.
Use `widenconst(t)` to get the native Julia type of `x`.
"""
# `IRCode`: static-parameter types default to the IR's own; argument types come from `ir.argtypes`.
argextype(@nospecialize(x), ir::IRCode, sptypes::Vector{Any} = ir.sptypes) =
    argextype(x, ir, sptypes, ir.argtypes)
# `IncrementalCompact`: SSA values (`AnySSAValue`) are resolved through `types(compact)`;
# everything else defers to the underlying IR's argument types.
function argextype(@nospecialize(x), compact::IncrementalCompact, sptypes::Vector{Any} = compact.ir.sptypes)
    isa(x, AnySSAValue) && return types(compact)[x]
    return argextype(x, compact, sptypes, compact.ir.argtypes)
end
# `CodeInfo`: the caller supplies `sptypes`; `src.slottypes` must already be a `Vector{Any}`.
argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{Any}) = argextype(x, src, sptypes, src.slottypes::Vector{Any})
# Core implementation of `argextype`: map an argument-position value `x` to its lattice
# type, using `sptypes` for static parameters and `slottypes` for slots and arguments.
function argextype(
    @nospecialize(x), src::Union{IRCode,IncrementalCompact,CodeInfo},
    sptypes::Vector{Any}, slottypes::Vector{Any})
    if x isa Expr
        head = x.head
        if head === :static_parameter
            return sptypes[x.args[1]::Int]
        elseif head === :boundscheck
            return Bool
        elseif head === :copyast
            return argextype(x.args[1], src, sptypes, slottypes)
        end
        @assert false "argextype only works on argument-position values"
    elseif x isa SlotNumber
        return slottypes[x.id]
    elseif x isa Argument
        return slottypes[x.n]
    elseif x isa TypedSlot
        return x.typ
    elseif x isa PiNode
        return x.typ
    elseif x isa SSAValue
        return abstract_eval_ssavalue(x, src)
    elseif x isa QuoteNode
        return Const(x.value)
    elseif x isa GlobalRef
        return abstract_eval_global(x.mod, x.name)
    elseif x isa PhiNode
        return Any
    else
        # any other value is a literal that stands for itself
        return Const(x)
    end
end
# Type of an SSA value as recorded in the IR's type table.
function abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact})
    return types(src)[s]
end

# Returned by `finish` when a method's result is a proven-pure constant, so that the
# constant calling convention may be used; `val` is that constant value.
struct ConstAPI
    val
    ConstAPI(@nospecialize val) = new(val)
end

"""
    finish(interp::AbstractInterpreter, opt::OptimizationState,
           params::OptimizationParams, ir::IRCode, caller::InferenceResult) -> analyzed::Union{Nothing,ConstAPI}

Post process information derived by Julia-level optimizations for later uses:
- computes "purity", i.e. side-effect-freeness
- computes inlining cost

It stores `ir` into `opt.ir` and updates `opt.src.inlineable` (and `opt.src.pure` when
purity is proven).

In a case when the purity is proven, `finish` can return `ConstAPI` object wrapping the constant
value so that the runtime system will use the constant calling convention for the method calls.
"""
function finish(interp::AbstractInterpreter, opt::OptimizationState,
                params::OptimizationParams, ir::IRCode, caller::InferenceResult)
    (; src, linfo) = opt
    (; def, specTypes) = linfo

    analyzed = nothing # `ConstAPI` if this call can use constant calling convention
    # an explicit `:noinline` meta on the IR disables inlining unconditionally
    force_noinline = _any(x::Expr -> x.head === :meta && x.args[1] === :noinline, ir.meta)

    # compute inlining and other related optimizations
    result = caller.result
    @assert !(result isa LimitedAccuracy)
    result = isa(result, InterConditional) ? widenconditional(result) : result
    # only a constant result (a `Const` or a constant `Type{T}`) can use the constant calling convention
    if (isa(result, Const) || isconstType(result))
        proven_pure = false
        # must be proven pure to use constant calling convention;
        # otherwise we might skip throwing errors (issue #20704)
        # TODO: Improve this analysis; if a function is marked @pure we should really
        # only care about certain errors (e.g. method errors and type errors).
        # the purity scan is only attempted for small bodies (fewer than 15 statements)
        if length(ir.stmts) < 15
            proven_pure = true
            for i in 1:length(ir.stmts)
                node = ir.stmts[i]
                stmt = node[:inst]
                if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
                    proven_pure = false
                    break
                end
            end
            if proven_pure
                # a slot that may be read before assignment could raise UndefVarError
                for fl in src.slotflags
                    if (fl & SLOT_USEDUNDEF) != 0
                        proven_pure = false
                        break
                    end
                end
            end
        end

        if proven_pure
            # use constant calling convention
            # Do not emit `jl_fptr_const_return` if coverage is enabled
            # so that we don't need to add coverage support
            # to the `jl_call_method_internal` fast path
            # Still set pure flag to make sure `inference` tests pass
            # and to possibly enable more optimization in the future
            src.pure = true
            if isa(result, Const)
                val = result.val
                if is_inlineable_constant(val)
                    analyzed = ConstAPI(val)
                end
            else
                @assert isconstType(result)
                analyzed = ConstAPI(result.parameters[1])
            end
            force_noinline || (src.inlineable = true)
        end
    end

    # keep the IR around so it can later be lowered back into `src` (`ir_to_codeinf!`)
    opt.ir = ir

    # determine and cache inlineability
    union_penalties = false
    if !force_noinline
        # a non-`Tuple` signature is never inlined; a `Union`-typed argument turns on
        # union penalties in the cost model
        sig = unwrap_unionall(specTypes)
        if isa(sig, DataType) && sig.name === Tuple.name
            for P in sig.parameters
                P = unwrap_unionall(P)
                if isa(P, Union)
                    union_penalties = true
                    break
                end
            end
        else
            force_noinline = true
        end
        # an always-throwing method is not inlined unless it was already marked inlineable
        if !src.inlineable && result === Bottom
            force_noinline = true
        end
    end
    if force_noinline
        src.inlineable = false
    elseif isa(def, Method)
        if src.inlineable && isdispatchtuple(specTypes)
            # obey @inline declaration if a dispatch barrier would not help
        else
            # compute the cost (size) of inlining this code
            cost_threshold = default = params.inline_cost_threshold
            # extra budget when the result is a non-concrete `Tuple` type
            if result ⊑ Tuple && !isconcretetype(widenconst(result))
                cost_threshold += params.inline_tupleret_bonus
            end
            # if the method is declared as `@inline`, increase the cost threshold 20x
            if src.inlineable
                cost_threshold += 19*default
            end
            # a few functions get special treatment
            if def.module === _topmod(def.module)
                name = def.name
                if name === :iterate || name === :unsafe_convert || name === :cconvert
                    cost_threshold += 4*default
                end
            end
            src.inlineable = inline_worthy(ir, params, union_penalties, cost_threshold)
        end
    end

    return analyzed
end

# Run the optimizer pipeline on `opt.src`, then post-process the resulting IR with
# `finish` and return its result (`nothing` or a `ConstAPI`).
function optimize(interp::AbstractInterpreter, opt::OptimizationState,
                  params::OptimizationParams, caller::InferenceResult)
    @timeit "optimizer" optimized_ir = run_passes(opt.src, opt, caller)
    return finish(interp, opt, params, optimized_ir, caller)
end

using .EscapeAnalysis
import .EscapeAnalysis: EscapeState, ArgEscapeCache, is_ipo_profitable

"""
    cache_escapes!(caller::InferenceResult, estate::EscapeState)

Summarize the escape information of `caller`'s call arguments from `estate` as an
`ArgEscapeCache` and store it in `caller.argescapes` for later interprocedural
propagation. Returns the stored cache.
"""
function cache_escapes!(caller::InferenceResult, estate::EscapeState)
    summary = ArgEscapeCache(estate)
    caller.argescapes = summary
    return summary
end

# Build a lookup that returns the cached `ArgEscapeCache` for a callee. For an
# `InferenceResult` it is read directly; for a `MethodInstance` it comes from the
# `CodeInstance` found in `mi_cache`. The lookup returns `nothing` when no escape
# information is available.
function ipo_escape_cache(mi_cache::MICache) where MICache
    return function (linfo::Union{InferenceResult,MethodInstance})
        cached = if linfo isa InferenceResult
            linfo.argescapes
        else
            codeinst = get(mi_cache, linfo, nothing)
            codeinst isa CodeInstance || return nothing
            codeinst.argescapes
        end
        cached === nothing && return nothing
        return cached::ArgEscapeCache
    end
end
# Escape-cache lookup that never provides any information.
function null_escape_cache(linfo::Union{InferenceResult,MethodInstance})
    return nothing
end

# `@pass name expr`: run `expr` under `@timeit name`, then increment the caller's
# `__stage__` counter. If `matchpass(optimize_until, stage, name)` returns true, jump to
# the caller's `@label __done__`. `optimize_until`, `__stage__` and `__done__` are escaped,
# so they must exist in the calling function (see `run_passes`).
macro pass(name, expr)
    optimize_until = esc(:optimize_until)
    stage = esc(:__stage__)
    macrocall = :(@timeit $(esc(name)) $(esc(expr)))
    # args[2] of a `:macrocall` Expr is its source-location slot; use the `@pass` call site
    macrocall.args[2] = __source__  # `@timeit` may want to use it
    quote
        $macrocall
        matchpass($optimize_until, ($stage += 1), $(esc(name))) && $(esc(:(@goto __done__)))
    end
end

# Decide whether `run_passes` should stop right after the pass `name`, which has just
# advanced the stage counter to `stage`:
# - an integer `optimize_until` stops once `stage` exceeds it (i.e. after that many passes);
# - a string `optimize_until` stops right after the pass with exactly that name. Any
#   `AbstractString` (e.g. a `SubString`) is accepted, as `run_passes` documents;
# - `nothing` never stops early.
matchpass(optimize_until::Integer, stage, _name) = optimize_until < stage
matchpass(optimize_until::AbstractString, _stage, name) = optimize_until == name
matchpass(::Nothing, _, _) = false

# Run the optimizer pipeline on `ci` and return the resulting `IRCode`.
# `optimize_until` stops early, which is useful for debugging (see `matchpass`):
# - an integer n stops after the first n passes;
# - a pass name stops right after that pass;
# - `nothing` (the default) runs every pass.
# The IR is additionally verified at `debug_level == 2`.
function run_passes(
    ci::CodeInfo,
    sv::OptimizationState,
    caller::InferenceResult,
    optimize_until = nothing,  # run all passes by default
)
    __stage__ = 1  # used by @pass
    # NOTE: The pass name MUST be unique for `optimize_until::AbstractString` to work
    @pass "convert"   ir = convert_to_ircode(ci, sv)
    @pass "slot2reg"  ir = slot2reg(ir, ci, sv)
    # TODO: Domsorting can produce an updated domtree - no need to recompute here
    @pass "compact 1" ir = compact!(ir)
    @pass "Inlining"  ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
    # @timeit "verify 2" verify_ir(ir)
    @pass "compact 2" ir = compact!(ir)
    @pass "SROA"      ir = sroa_pass!(ir)
    @pass "ADCE"      ir = adce_pass!(ir)
    @pass "type lift" ir = type_lift_pass!(ir)
    @pass "compact 3" ir = compact!(ir)
    if JLOptions().debug_level == 2
        @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
    end
    @label __done__  # used by @pass
    return ir
end

# Build the initial (slot-based) `IRCode` from `ci`:
# - when coverage applies, insert an `Expr(:code_coverage_effect)` before the first
#   statement of each new nonzero code location;
# - after every `Expr` statement inferred as `Union{}`, insert an unreachable
#   `ReturnNode()` unless one already follows;
# - renumber SSA values and jump targets to account for these insertions;
# - move `:meta` expressions out of the body into the IR's `meta` list.
# NOTE(review): `codelocs`, `ssavaluetypes`, `ssaflags` (fields of `ci`) and `stmtinfo`
# (`sv.stmt_info`) are mutated in place by the insertions below. `code` comes from
# `copy_exprargs(ci.code)`, presumably a copy.
function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
    linetable = ci.linetable
    if !isa(linetable, Vector{LineInfoNode})
        linetable = collect(LineInfoNode, linetable::Vector{Any})::Vector{LineInfoNode}
    end

    # check if coverage mode is enabled
    coverage = coverage_enabled(sv.mod)
    if !coverage && JLOptions().code_coverage == 3 # path-specific coverage mode
        for line in linetable
            if is_file_tracked(line.file)
                # if any line falls in a tracked file enable coverage for all
                coverage = true
                break
            end
        end
    end

    # Go through and add an unreachable node after every
    # Union{} call. Then reindex labels.
    code = copy_exprargs(ci.code)
    stmtinfo = sv.stmt_info
    codelocs = ci.codelocs
    ssavaluetypes = ci.ssavaluetypes::Vector{Any}
    ssaflags = ci.ssaflags
    meta = Expr[]
    idx = 1
    oldidx = 1
    # Per-original-statement insertion counts, later turned into cumulative offsets.
    # With coverage, jumps to a statement must land on the coverage effect inserted before
    # it, so label shifts lag SSA shifts by one statement and need a separate map.
    ssachangemap = fill(0, length(code))
    labelchangemap = coverage ? fill(0, length(code)) : ssachangemap
    prevloc = zero(eltype(ci.codelocs))
    while idx <= length(code)
        codeloc = codelocs[idx]
        if coverage && codeloc != prevloc && codeloc != 0
            # insert a side-effect instruction before the current instruction in the same basic block
            insert!(code, idx, Expr(:code_coverage_effect))
            insert!(codelocs, idx, codeloc)
            insert!(ssavaluetypes, idx, Nothing)
            insert!(stmtinfo, idx, nothing)
            insert!(ssaflags, idx, IR_FLAG_NULL)
            ssachangemap[oldidx] += 1
            if oldidx < length(labelchangemap)
                labelchangemap[oldidx + 1] += 1
            end
            idx += 1
            prevloc = codeloc
        end
        # a `Union{}`-typed statement never returns, so what follows it is unreachable
        if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
            if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
                # insert unreachable in the same basic block after the current instruction (splitting it)
                insert!(code, idx + 1, ReturnNode())
                insert!(codelocs, idx + 1, codelocs[idx])
                insert!(ssavaluetypes, idx + 1, Union{})
                insert!(stmtinfo, idx + 1, nothing)
                insert!(ssaflags, idx + 1, ssaflags[idx])
                # statements after this one shift by one (labels too, when tracked separately)
                if oldidx < length(ssachangemap)
                    ssachangemap[oldidx + 1] += 1
                    coverage && (labelchangemap[oldidx + 1] += 1)
                end
                idx += 1
            end
        end
        idx += 1
        oldidx += 1
    end

    renumber_ir_elements!(code, ssachangemap, labelchangemap)

    for i = 1:length(code)
        code[i] = process_meta!(meta, code[i])
    end
    strip_trailing_junk!(ci, code, stmtinfo)
    # NOTE(review): `types` starts empty here; presumably populated by later passes
    # (e.g. `construct_ssa!` in `slot2reg`) — confirm
    types = Any[]
    stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
    cfg = compute_basic_blocks(code)
    return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes)
end

# Move a non-empty `:meta` expression out of the statement stream into `meta`, returning
# `nothing` as its replacement. Any other statement is returned unchanged.
function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
    is_meta = isexpr(stmt, :meta) && length(stmt.args) ≥ 1
    is_meta || return stmt
    push!(meta, stmt)
    return nothing
end

# Convert slot-based IR into SSA form. Slot metadata comes from `ci`, the code from `ir`,
# and the dominator tree is computed from `ir`'s CFG. `ir` is consumed by
# `construct_ssa!`, so use the returned IR.
function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState)
    mdef = sv.linfo.def
    nargs = mdef isa Method ? Int(mdef.nargs) : 0
    @timeit "domtree 1" domtree = construct_domtree(ir.cfg.blocks)
    defuse = scan_slot_def_use(nargs, ci, ir.stmts.inst)
    @timeit "construct_ssa" ssa_ir = construct_ssa!(ci, ir, domtree, defuse, sv.slottypes) # consumes `ir`
    return ssa_ir
end

## Computing the cost of a function body

# Overflow-safe sum of two nonnegative `Int`s. If `x + y` wraps around, it drops below
# both operands, so taking the maximum keeps the result nonnegative. This matters
# because some statements are priced at `typemax(Int)`.
function plus_saturate(x::Int, y::Int)
    s = x + y
    return max(s, x, y)
end

# Whether `T` is a "known" type for cost purposes: the bottom type, a constant, or a
# lattice element that widens to a concrete type.
function isknowntype(@nospecialize T)
    T === Union{} && return true
    isa(T, Const) && return true
    return isconcretetype(widenconst(T))
end

"""
    statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
                   union_penalties::Bool, params::OptimizationParams, error_path::Bool = false) -> Int

Estimate the inlining cost of expression `ex`, the statement at index `line` of `src`.
`line == -1` denotes a nested right-hand side, whose result type is then taken as `Any`.
`error_path` marks statements on a throwing path; where applicable these are charged
`params.inline_error_path_cost` instead of `params.inline_nonleaf_penalty`.
Calls inferred to return `Union{}` are free, and `:enter` (`try`) costs `typemax(Int)`
so that such functions are never inlined.
"""
function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
                        union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
    head = ex.head
    if is_meta_expr_head(head)
        return 0
    elseif head === :call
        farg = ex.args[1]
        ftyp = argextype(farg, src, sptypes)
        if ftyp === IntrinsicFunction && farg isa SSAValue
            # if this comes from code that was already inlined into another function,
            # Consts have been widened. try to recover in simple cases.
            farg = isa(src, CodeInfo) ? src.code[farg.id] : src.stmts[farg.id][:inst]
            if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
                ftyp = argextype(farg, src, sptypes)
            end
        end
        f = singleton_type(ftyp)
        if isa(f, IntrinsicFunction)
            iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1
            if !isassigned(T_IFUNC_COST, iidx)
                # unknown/unhandled intrinsic
                return params.inline_nonleaf_penalty
            end
            return T_IFUNC_COST[iidx]
        end
        if isa(f, Builtin)
            # The efficiency of operations like a[i] and s.b
            # depend strongly on whether the result can be
            # inferred, so check the type of ex
            if f === Core.getfield || f === Core.tuple || f === Core.getglobal
                # we might like to penalize non-inferrability, but
                # tuple iteration/destructuring makes that impossible
                # return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
                return 0
            elseif (f === Core.arrayref || f === Core.const_arrayref || f === Core.arrayset) && length(ex.args) >= 3
                atyp = argextype(ex.args[3], src, sptypes)
                return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
            elseif f === typeassert && length(ex.args) >= 3 &&
                   isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
                # A typeassert against a constant type is cheap. The length check keeps a
                # wrong-arity call from indexing past `ex.args`.
                return 1
            elseif f === Core.isa
                # If we're in a union context, we penalize type computations
                # on union types. In such cases, it is usually better to perform
                # union splitting on the outside.
                if union_penalties && length(ex.args) >= 2 &&
                   isa(argextype(ex.args[2], src, sptypes), Union)
                    return params.inline_nonleaf_penalty
                end
            end
            fidx = find_tfunc(f)
            if fidx === nothing
                # unknown/unhandled builtin
                # Use the generic cost of a direct function call
                return 20
            end
            return T_FFUNC_COST[fidx]
        end
        extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
        if extyp === Union{}
            return 0
        end
        return error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
    elseif head === :foreigncall || head === :invoke || head === :invoke_modify
        # Calls whose "return type" is Union{} do not actually return:
        # they are errors. Since these are not part of the typical
        # run-time of the function, we omit them from
        # consideration. This way, non-inlined error branches do not
        # prevent inlining.
        extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
        return extyp === Union{} ? 0 : 20
    elseif head === :(=)
        # assigning to a global costs like a call; a nested RHS expression adds its own cost
        if ex.args[1] isa GlobalRef
            cost = 20
        else
            cost = 0
        end
        a = ex.args[2]
        if a isa Expr
            cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path))
        end
        return cost
    elseif head === :copyast
        return 100
    elseif head === :enter
        # try/catch is a couple function calls,
        # but don't inline functions with try/catch
        # since these aren't usually performance-sensitive functions,
        # and llvm is more likely to miscompile them when these functions get large
        return typemax(Int)
    end
    return 0
end

# Cost of one statement for the inlining heuristic. Expressions are priced by
# `statement_cost`, with the error-path flag taken from the statement's IR flags. A
# backward `GotoNode`/`GotoIfNot` (a loop) costs 40. Anything else is free.
function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
                                  union_penalties::Bool, params::OptimizationParams)
    if stmt isa Expr
        flag = src isa IRCode ? src.stmts.flag[line] : src.ssaflags[line]
        return statement_cost(stmt, line, src, sptypes, union_penalties, params,
                              is_stmt_throw_block(flag))::Int
    end
    if stmt isa GotoNode
        label = stmt.label
    elseif stmt isa GotoIfNot
        label = stmt.dest
    else
        return 0
    end
    # in `IRCode`, branch targets are basic-block numbers: map to the block's first statement
    target = src isa IRCode ? first(src.cfg.blocks[label].stmts) : label
    # Forward jumps are already accounted for by summing the cost of the not-taken
    # branch. Backward jumps form loops, which are generally expensive.
    return target < line ? 40 : 0
end

# Whether the accumulated statement cost of `ir` stays within `cost_threshold`. Stops at
# the first statement that pushes the running total past the threshold.
function inline_worthy(ir::IRCode,
                       params::OptimizationParams, union_penalties::Bool=false, cost_threshold::Integer=params.inline_cost_threshold)
    total = 0
    sptypes = ir.sptypes
    for idx = 1:length(ir.stmts)
        inst = ir.stmts[idx][:inst]
        stmt_cost = statement_or_branch_cost(inst, idx, ir, sptypes, union_penalties, params)
        total = plus_saturate(total, stmt_cost)
        if total > cost_threshold
            return false
        end
    end
    return true
end

# Store the cost of each statement of `body` into `cost` (same indices) and return the
# largest single cost (0 for an empty body).
function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any}, unionpenalties::Bool, params::OptimizationParams)
    worst = 0
    for line = 1:length(body)
        c = statement_or_branch_cost(body[line], line, src, sptypes, unionpenalties, params)
        cost[line] = c
        worst = max(worst, c)
    end
    return worst
end

# Convenience form: use one map for both SSA values and jump labels.
renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}) =
    renumber_ir_elements!(body, ssachangemap, ssachangemap)

# Turn the per-statement deltas in `ssachangemap` into cumulative offsets, in place.
# A delta of exactly `-1` marks a deleted statement. It is replaced by the sentinel
# `typemin(Int)` but still contributes -1 to later offsets. Returns whether any delta
# was nonzero.
function cumsum_ssamap!(ssachangemap::Vector{Int})
    changed = false
    offset = 0
    for idx = 1:length(ssachangemap)
        delta = ssachangemap[idx]
        offset += delta
        changed = changed || delta != 0
        ssachangemap[idx] = delta == -1 ? typemin(Int) : offset
    end
    return changed
end

"""
    renumber_ir_elements!(body, ssachangemap, labelchangemap)

Rewrite SSA references and jump targets in `body`, in place, after statements were
inserted or deleted. Both maps hold per-statement deltas on entry and are converted to
cumulative offsets by `cumsum_ssamap!`. In those offsets, `typemin(Int)` marks a
deleted statement.

- `labelchangemap` adjusts jump targets (`GotoNode`, `GotoIfNot`, `:enter`).
- `ssachangemap` adjusts `SSAValue`s and phi edges.
- A `GotoIfNot` whose destination was deleted is replaced by its (renumbered) condition.
- Phi edges coming from deleted statements are dropped together with their values.
"""
function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, labelchangemap::Vector{Int})
    any_change = cumsum_ssamap!(labelchangemap)
    if ssachangemap !== labelchangemap
        any_change |= cumsum_ssamap!(ssachangemap)
    end
    any_change || return
    for i = 1:length(body)
        el = body[i]
        if isa(el, GotoNode)
            body[i] = GotoNode(el.label + labelchangemap[el.label])
        elseif isa(el, GotoIfNot)
            cond = el.cond
            if isa(cond, SSAValue)
                cond = SSAValue(cond.id + ssachangemap[cond.id])
            end
            was_deleted = labelchangemap[el.dest] == typemin(Int)
            body[i] = was_deleted ? cond : GotoIfNot(cond, el.dest + labelchangemap[el.dest])
        elseif isa(el, ReturnNode)
            if isdefined(el, :val)
                val = el.val
                if isa(val, SSAValue)
                    body[i] = ReturnNode(SSAValue(val.id + ssachangemap[val.id]))
                end
            end
        elseif isa(el, SSAValue)
            body[i] = SSAValue(el.id + ssachangemap[el.id])
        elseif isa(el, PhiNode)
            # `k` indexes the phi's edges; it is kept separate from the statement index `i`
            edges = el.edges
            values = el.values
            k = 1
            while k <= length(edges)
                if ssachangemap[edges[k]] == typemin(Int)
                    # the incoming edge's statement was deleted: drop the edge and its value
                    deleteat!(edges, k)
                    deleteat!(values, k)
                else
                    edges[k] += ssachangemap[edges[k]]
                    val = values[k]
                    if isa(val, SSAValue)
                        values[k] = SSAValue(val.id + ssachangemap[val.id])
                    end
                    k += 1
                end
            end
        elseif isa(el, Expr)
            # for an assignment whose RHS is an expression, renumber inside the RHS
            ex = el
            if ex.head === :(=) && ex.args[2] isa Expr
                ex = ex.args[2]::Expr
            end
            if ex.head === :enter
                tgt = ex.args[1]::Int
                ex.args[1] = tgt + labelchangemap[tgt]
            elseif !is_meta_expr_head(ex.head)
                args = ex.args
                for j = 1:length(args)
                    arg = args[j]
                    if isa(arg, SSAValue)
                        args[j] = SSAValue(arg.id + ssachangemap[arg.id])
                    end
                end
            end
        end
    end
end
# back to top