Skip to content

Commit

Permalink
Merge pull request #2030 from ruby/optional-type-param-limit
Browse files Browse the repository at this point in the history
Optional type param restriction
  • Loading branch information
soutaro committed Sep 27, 2024
2 parents abd609b + 8c18b9c commit 4e5ce53
Show file tree
Hide file tree
Showing 12 changed files with 238 additions and 44 deletions.
55 changes: 44 additions & 11 deletions lib/rbs/ast/type_param.rb
Original file line number Diff line number Diff line change
Expand Up @@ -154,39 +154,72 @@ def to_s
end

def self.application(params, args)
subst = Substitution.new()

if params.empty?
return nil
end

min_count = params.count { _1.default_type.nil? }
max_count = params.size
optional_params, required_params = params.partition {|param| param.default_type }

param_subst = Substitution.new()
app_subst = Substitution.new()

required_params.zip(args.take(required_params.size)).each do |param, arg|
arg ||= Types::Bases::Any.new(location: nil)
param_subst.add(from: param.name, to: arg)
app_subst.add(from: param.name, to: arg)
end

unless min_count <= args.size && args.size <= max_count
raise "Invalid type application: required type params=#{min_count}, optional type params=#{max_count - min_count}, given args=#{args.size}"
optional_params.each do |param|
param_subst.add(from: param.name, to: Types::Bases::Any.new(location: nil))
end

params.zip(args).each do |param, arg|
optional_params.zip(args.drop(required_params.size)).each do |param, arg|
if arg
subst.add(from: param.name, to: arg)
app_subst.add(from: param.name, to: arg)
else
subst.add(from: param.name, to: param.default_type || raise)
param.default_type or raise
app_subst.add(from: param.name, to: param.default_type.sub(param_subst))
end
end

subst
app_subst
end

def self.normalize_args(params, args)
app = application(params, args) or return []

min_count = params.count { _1.default_type.nil? }
unless min_count <= args.size && args.size <= params.size
return args
end

params.zip(args).filter_map do |param, arg|
if arg
arg
else
param.default_type
if param.default_type
param.default_type.sub(app)
else
Types::Bases::Any.new(location: nil)
end
end
end
end

def self.validate(type_params)
optionals = type_params.filter {|param| param.default_type }

optional_param_names = optionals.map(&:name).sort

optionals.filter! do |param|
default_type = param.default_type or raise
optional_param_names.any? { default_type.free_variables.include?(_1) }
end

unless optionals.empty?
optionals
end
end
end
end
end
10 changes: 6 additions & 4 deletions lib/rbs/cli/validate.rb
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,6 @@ def validate_class_module_definition
no_classish_type_validator(arg)
@validator.validate_type(arg, context: nil)
end

if super_entry = @env.normalized_class_entry(super_class.name)
InvalidTypeApplicationError.check!(type_name: super_class.name, args: super_class.args, params: super_entry.type_params, location: super_class.location)
end
end
end
when Environment::ModuleEntry
Expand Down Expand Up @@ -171,6 +167,8 @@ def validate_class_module_definition
end
end

TypeParamDefaultReferenceError.check!(d.type_params)

entry.decls.each do |d|
d.decl.each_member do |member|
case member
Expand Down Expand Up @@ -248,6 +246,8 @@ def validate_interface
end
end

TypeParamDefaultReferenceError.check!(decl.decl.type_params)

decl.decl.members.each do |member|
case member
when AST::Members::MethodDefinition
Expand Down Expand Up @@ -319,6 +319,8 @@ def validate_type_alias
end
end

TypeParamDefaultReferenceError.check!(decl.decl.type_params)

no_self_type_validator(decl.decl.type)
no_classish_type_validator(decl.decl.type)
void_type_context_validator(decl.decl.type)
Expand Down
10 changes: 2 additions & 8 deletions lib/rbs/definition.rb
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,8 @@ def initialize(type_name:, params:, ancestors:)
@ancestors = ancestors
end

def apply(args, location:)
# Assume default types of type parameters are already added to `args`
InvalidTypeApplicationError.check!(
type_name: type_name,
args: args,
params: params.map { AST::TypeParam.new(name: _1, variance: :invariant, upper_bound: nil, location: nil, default_type: nil) },
location: location
)
def apply(args, env:, location:)
InvalidTypeApplicationError.check2!(env: env, type_name: type_name, args: args, location: location)

subst = Substitution.build(params, args)

Expand Down
13 changes: 7 additions & 6 deletions lib/rbs/definition_builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def tapp_subst(name, args)
raise
end

Substitution.build(params.map(&:name), args)
AST::TypeParam.application(params, args) || Substitution.new()
end

def define_instance(definition, type_name, subst)
Expand Down Expand Up @@ -549,16 +549,17 @@ def import_methods(definition, module_name, module_methods, interfaces_methods,
interfaces_methods.each do |interface, (methods, member)|
unless interface.args.empty?
methods.type.is_a?(Types::Interface) or raise
params = methods.type.args.map do |arg|
arg.is_a?(Types::Variable) or raise
arg.name
end

interface.args.each do |arg|
validate_type_presence(arg)
end

subst_ = subst + Substitution.build(params, interface.args)
type_params = env.interface_decls.fetch(interface.name).decl.type_params
if s = AST::TypeParam.application(type_params, interface.args)
subst_ = subst + s
else
subst_ = subst
end
else
subst_ = subst
end
Expand Down
30 changes: 23 additions & 7 deletions lib/rbs/definition_builder/ancestor_builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def one_instance_ancestors(type_name)
end

super_name = env.normalize_module_name(super_name)

NoSuperclassFoundError.check!(super_name, env: env, location: primary.decl.location)
if super_class
InheritModuleError.check!(super_class, env: env)
Expand Down Expand Up @@ -458,7 +458,9 @@ def instance_ancestors(type_name, building_ancestors: [])
super_name = super_class.name
super_args = super_class.args

super_ancestors = instance_ancestors(super_name, building_ancestors: building_ancestors).apply(super_args, location: entry.primary.decl.location)
super_ancestors =
instance_ancestors(super_name, building_ancestors: building_ancestors)
.apply(super_args, env: env, location: entry.primary.decl.super_class&.location)
super_ancestors.map! {|ancestor| fill_ancestor_source(ancestor, name: super_name, source: :super) }
ancestors.unshift(*super_ancestors)
end
Expand All @@ -477,7 +479,10 @@ def instance_ancestors(type_name, building_ancestors: [])
included_modules.each do |mod|
name = mod.name
arg_types = mod.args
mod_ancestors = instance_ancestors(name, building_ancestors: building_ancestors).apply(arg_types, location: entry.primary.decl.location)
mod.source.is_a?(AST::Members::Include) or raise
mod_ancestors =
instance_ancestors(name, building_ancestors: building_ancestors)
.apply(arg_types, env: env, location: mod.source.location)
mod_ancestors.map! {|ancestor| fill_ancestor_source(ancestor, name: name, source: mod.source) }
ancestors.unshift(*mod_ancestors)
end
Expand All @@ -489,7 +494,10 @@ def instance_ancestors(type_name, building_ancestors: [])
prepended_modules.each do |mod|
name = mod.name
arg_types = mod.args
mod_ancestors = instance_ancestors(name, building_ancestors: building_ancestors).apply(arg_types, location: entry.primary.decl.location)
mod.source.is_a?(AST::Members::Prepend) or raise
mod_ancestors =
instance_ancestors(name, building_ancestors: building_ancestors)
.apply(arg_types, env: env, location: mod.source.location)
mod_ancestors.map! {|ancestor| fill_ancestor_source(ancestor, name: name, source: mod.source) }
ancestors.unshift(*mod_ancestors)
end
Expand Down Expand Up @@ -524,7 +532,9 @@ def singleton_ancestors(type_name, building_ancestors: [])
super_name = super_class.name
super_args = super_class.args

super_ancestors = instance_ancestors(super_name, building_ancestors: building_ancestors).apply(super_args, location: entry.primary.decl.location)
super_ancestors =
instance_ancestors(super_name, building_ancestors: building_ancestors)
.apply(super_args, env: env, location: nil)
super_ancestors.map! {|ancestor| fill_ancestor_source(ancestor, name: super_name, source: :super) }
ancestors.unshift(*super_ancestors)

Expand All @@ -539,7 +549,10 @@ def singleton_ancestors(type_name, building_ancestors: [])
extended_modules.each do |mod|
name = mod.name
args = mod.args
mod_ancestors = instance_ancestors(name, building_ancestors: building_ancestors).apply(args, location: entry.primary.decl.location)
mod.source.is_a?(AST::Members::Extend) or raise
mod_ancestors =
instance_ancestors(name, building_ancestors: building_ancestors)
.apply(args, env: env, location: mod.source.location)
mod_ancestors.map! {|ancestor| fill_ancestor_source(ancestor, name: name, source: mod.source) }
ancestors.unshift(*mod_ancestors)
end
Expand Down Expand Up @@ -572,7 +585,10 @@ def interface_ancestors(type_name, building_ancestors: [])

included_interfaces = one_ancestors.included_interfaces or raise
included_interfaces.each do |a|
included_ancestors = interface_ancestors(a.name, building_ancestors: building_ancestors).apply(a.args, location: entry.decl.location)
a.source.is_a?(AST::Members::Include) or raise
included_ancestors =
interface_ancestors(a.name, building_ancestors: building_ancestors)
.apply(a.args, env: env, location: a.source.location)
included_ancestors.map! {|ancestor| fill_ancestor_source(ancestor, name: a.name, source: a.source) }
ancestors.unshift(*included_ancestors)
end
Expand Down
36 changes: 36 additions & 0 deletions lib/rbs/errors.rb
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ def self.check!(type_name:, args:, params:, location:)
raise new(type_name: type_name, args: args, params: params, location: location)
end
end

def self.check2!(env:, type_name:, args:, location:)
params =
case
when type_name.class?
decl = env.normalized_module_class_entry(type_name) or raise
decl.type_params
when type_name.interface?
env.interface_decls.fetch(type_name).decl.type_params
when type_name.alias?
env.type_alias_decls.fetch(type_name).decl.type_params
else
raise
end

check!(type_name: type_name, args: args, params: params, location: location)
end
end

class RecursiveAncestorError < DefinitionError
Expand Down Expand Up @@ -559,4 +576,23 @@ def initialize(message, location:)
@location = location
end
end

class TypeParamDefaultReferenceError < DefinitionError
include DetailedMessageable

attr_reader :location

def initialize(type_param, location:)
super "#{Location.to_string(location)}: the default of #{type_param.name} cannot include optional type parameter"
@location = location
end

def self.check!(type_params)
if errors = AST::TypeParam.validate(type_params)
error = errors[0] or raise
error.default_type or raise
raise new(error, location: error.default_type.location)
end
end
end
end
2 changes: 1 addition & 1 deletion sig/definition.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ module RBS

def initialize: (type_name: TypeName, params: Array[Symbol], ancestors: Array[Ancestor::t]) -> void

def apply: (Array[Types::t], location: Location[untyped, untyped]?) -> Array[Ancestor::t]
def apply: (Array[Types::t], env: Environment, location: Location[untyped, untyped]?) -> Array[Ancestor::t]
end

class SingletonAncestors
Expand Down
12 changes: 12 additions & 0 deletions sig/errors.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ module RBS
def initialize: (type_name: TypeName, args: Array[Types::t], params: Array[AST::TypeParam], location: Location[untyped, untyped]?) -> void

def self.check!: (type_name: TypeName, args: Array[Types::t], params: Array[AST::TypeParam], location: Location[untyped, untyped]?) -> void

def self.check2!: (env: Environment, type_name: TypeName, args: Array[Types::t], location: Location[untyped, untyped]?) -> void
end

class RecursiveAncestorError < DefinitionError
Expand Down Expand Up @@ -366,4 +368,14 @@ module RBS

attr_reader location: Location[untyped, untyped]?
end

class TypeParamDefaultReferenceError < BaseError
include DetailedMessageable

def initialize: (AST::TypeParam, location: Location[untyped, untyped]?) -> void

attr_reader location: Location[untyped, untyped]?

def self.check!: (Array[AST::TypeParam]) -> void
end
end
3 changes: 1 addition & 2 deletions sig/members.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ module RBS
# include Array[String]
# ^^^^^^^ keyword
# ^^^^^ name
# ^ arg_open
# ^ arg_close
# ^^^^^^^^ args
#
type loc = Location[:name | :keyword, :args]

Expand Down
14 changes: 9 additions & 5 deletions sig/type_param.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@ module RBS

def to_s: () -> String

# Returns an application with respect to type params` default
# Validates TypeParams if it refers another optiional type params
#
# * Returns array of TypeParam objects that refers other optional type params
# * Returns `nil` if all type params are valid
#
def self.validate: (Array[TypeParam]) -> Array[TypeParam]?

# Returns an application with respect to type params' default
#
def self.application: (Array[TypeParam], Array[Types::t]) -> Substitution?

Expand All @@ -95,10 +102,7 @@ module RBS
# _Foo[String, Integer, untyped] # => _Foo[String, Integer, untyped] (Keeping extra args)
# ```
#
# Note that it allows iinvalid arities.
#
# * Missing args will be omitted
# * Extra args will be keeped
# Note that it allows iinvalid arities, returning the `args` immediately.
#
def self.normalize_args: (Array[TypeParam], Array[Types::t]) -> Array[Types::t]
end
Expand Down
Loading

0 comments on commit 4e5ce53

Please sign in to comment.