Skip to content

Commit

Permalink
compiler: general refactor (#41633)
Browse files Browse the repository at this point in the history
Separated from compiler-plugin prototyping.

cherry-picked from 799136d
  • Loading branch information
aviatesk committed Sep 9, 2021
1 parent c85012a commit 1232010
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 119 deletions.
192 changes: 110 additions & 82 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,73 +35,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
add_remark!(interp, sv, "Skipped call in throw block")
return CallMeta(Any, false)
end
valid_worlds = WorldRange()
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
splitunions = 1 < unionsplitcost(argtypes) <= InferenceParams(interp).MAX_UNION_SPLITTING
mts = Core.MethodTable[]
fullmatch = Bool[]
if splitunions
split_argtypes = switchtupleunion(argtypes)
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
infos = MethodMatchInfo[]
for arg_n in split_argtypes
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
if mt === nothing
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
mt = mt::Core.MethodTable
matches = findall(sig_n, method_table(interp); limit=max_methods)
if matches === missing
add_remark!(interp, sv, "For one of the union split cases, too many methods matched")
return CallMeta(Any, false)
end
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatch[i] &= thisfullmatch
found = true
break
end
end
if !found
push!(mts, mt)
push!(fullmatch, thisfullmatch)
end
end
info = UnionSplitInfo(infos)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
mt = mt::Core.MethodTable
matches = findall(atype, method_table(interp, sv); limit=max_methods)
if matches === missing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
add_remark!(interp, sv, "Too many methods matched")
return CallMeta(Any, false)
end
push!(mts, mt)
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches))
info = MethodMatchInfo(matches)
applicable = matches.matches
valid_worlds = matches.valid_worlds
applicable_argtypes = nothing

matches = find_matching_methods(argtypes, atype, method_table(interp, sv), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
return CallMeta(Any, false)
end

(; valid_worlds, applicable, info) = matches
update_valid_age!(sv, valid_worlds)
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
rettype = Bottom
edges = MethodInstance[]
Expand Down Expand Up @@ -142,7 +84,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if edge !== nothing
push!(edges, edge)
end
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
if const_rt !== rt && const_rt rt
rt = const_rt
Expand All @@ -164,7 +106,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
if const_this_rt !== this_rt && const_this_rt this_rt
this_rt = const_this_rt
Expand Down Expand Up @@ -272,7 +214,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
add_call_backedges!(interp, rettype, edges, matches, atype, sv)
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in sv.callers_in_cycle
Expand All @@ -283,24 +225,110 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(rettype, info)
end

function add_call_backedges!(interp::AbstractInterpreter,
@nospecialize(rettype),
edges::Vector{MethodInstance},
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
sv::InferenceState)
if rettype === Any
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
# (widen) this type
return
struct FailedMethodMatch
reason::String
end

struct MethodMatches
applicable::Vector{Any}
info::MethodMatchInfo
valid_worlds::WorldRange
mt::Core.MethodTable
fullmatch::Bool
end

struct UnionSplitMethodMatches
applicable::Vector{Any}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
mts::Vector{Core.MethodTable}
fullmatches::Vector{Bool}
end

function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
union_split::Int, max_methods::Int)
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
if 1 < unionsplitcost(argtypes) <= union_split
split_argtypes = switchtupleunion(argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
mts = Core.MethodTable[]
fullmatches = Bool[]
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::Core.MethodTable
matches = findall(sig_n, method_table; limit = max_methods)
if matches === missing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
break
end
end
if !found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
end
return UnionSplitMethodMatches(applicable,
applicable_argtypes,
UnionSplitInfo(infos),
valid_worlds,
mts,
fullmatches)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::Core.MethodTable
matches = findall(atype, method_table; limit = max_methods)
if matches === missing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
matches.valid_worlds,
mt,
fullmatch)
end
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::InferenceState)
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine (widen) this type
rettype === Any && return
for edge in edges
add_backedge!(edge, sv)
end
for (thisfullmatch, mt) in zip(fullmatch, mts)
if !thisfullmatch
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge!(mt, atype, sv)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
thisfullmatch || add_mt_backedge!(mt, atype, sv)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
# so that we can construct cache-correct `InferenceResult`s in the first place.
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override)
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool)
@assert isa(linfo.def, Method) # ensure the next line works
nargs::Int = linfo.def.nargs
@assert length(given_argtypes) >= (nargs - 1)
Expand Down
25 changes: 13 additions & 12 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,11 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
return true
end

# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects
# compute inlining cost and sideeffects
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result))
def = opt.linfo.def
nargs = Int(opt.nargs) - 1
(; src, nargs, linfo) = opt
(; def, specTypes) = linfo
nargs = Int(nargs) - 1

force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)

Expand All @@ -221,7 +222,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
end
end
if proven_pure
for fl in opt.src.slotflags
for fl in src.slotflags
if (fl & SLOT_USEDUNDEF) != 0
proven_pure = false
break
Expand All @@ -230,7 +231,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
end
end
if proven_pure
opt.src.pure = true
src.pure = true
end

if proven_pure
Expand All @@ -243,7 +244,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
if !(isa(result, Const) && !is_inlineable_constant(result.val))
opt.const_api = true
end
force_noinline || (opt.src.inlineable = true)
force_noinline || (src.inlineable = true)
end
end

Expand All @@ -252,7 +253,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
# determine and cache inlineability
union_penalties = false
if !force_noinline
sig = unwrap_unionall(opt.linfo.specTypes)
sig = unwrap_unionall(specTypes)
if isa(sig, DataType) && sig.name === Tuple.name
for P in sig.parameters
P = unwrap_unionall(P)
Expand All @@ -264,25 +265,25 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
else
force_noinline = true
end
if !opt.src.inlineable && result === Union{}
if !src.inlineable && result === Union{}
force_noinline = true
end
end
if force_noinline
opt.src.inlineable = false
src.inlineable = false
elseif isa(def, Method)
if opt.src.inlineable && isdispatchtuple(opt.linfo.specTypes)
if src.inlineable && isdispatchtuple(specTypes)
# obey @inline declaration if a dispatch barrier would not help
else
bonus = 0
if result Tuple && !isconcretetype(widenconst(result))
bonus = params.inline_tupleret_bonus
end
if opt.src.inlineable
if src.inlineable
# For functions declared @inline, increase the cost threshold 20x
bonus += params.inline_cost_threshold*19
end
opt.src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
end
end

Expand Down
Loading

0 comments on commit 1232010

Please sign in to comment.