aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/irb.rb4
-rw-r--r--lib/irb/cmd/irb_info.rb1
-rw-r--r--lib/irb/completion.rb49
-rw-r--r--lib/irb/context.rb41
-rw-r--r--lib/irb/init.rb1
-rw-r--r--lib/irb/input-method.rb15
-rw-r--r--lib/irb/type_completion/completor.rb235
-rw-r--r--lib/irb/type_completion/methods.rb13
-rw-r--r--lib/irb/type_completion/scope.rb412
-rw-r--r--lib/irb/type_completion/type_analyzer.rb1169
-rw-r--r--lib/irb/type_completion/types.rb426
-rw-r--r--test/irb/test_cmd.rb5
-rw-r--r--test/irb/test_context.rb18
-rw-r--r--test/irb/test_input_method.rb10
-rw-r--r--test/irb/type_completion/test_scope.rb112
-rw-r--r--test/irb/type_completion/test_type_analyze.rb697
-rw-r--r--test/irb/type_completion/test_type_completor.rb181
-rw-r--r--test/irb/type_completion/test_types.rb89
18 files changed, 3446 insertions, 32 deletions
diff --git a/lib/irb.rb b/lib/irb.rb
index d0688e6f9f..655abaf069 100644
--- a/lib/irb.rb
+++ b/lib/irb.rb
@@ -140,6 +140,10 @@ require_relative "irb/debug"
#
# IRB.conf[:USE_AUTOCOMPLETE] = false
#
+# To enable enhanced completion using type information, add the following to your +.irbrc+:
+#
+# IRB.conf[:COMPLETOR] = :type
+#
# === History
#
# By default, irb will store the last 1000 commands you used in
diff --git a/lib/irb/cmd/irb_info.rb b/lib/irb/cmd/irb_info.rb
index 75fdc38676..5b905a09bd 100644
--- a/lib/irb/cmd/irb_info.rb
+++ b/lib/irb/cmd/irb_info.rb
@@ -14,6 +14,7 @@ module IRB
str = "Ruby version: #{RUBY_VERSION}\n"
str += "IRB version: #{IRB.version}\n"
str += "InputMethod: #{IRB.CurrentContext.io.inspect}\n"
+ str += "Completion: #{IRB.CurrentContext.io.respond_to?(:completion_info) ? IRB.CurrentContext.io.completion_info : 'off'}\n"
str += ".irbrc path: #{IRB.rc_file}\n" if File.exist?(IRB.rc_file)
str += "RUBY_PLATFORM: #{RUBY_PLATFORM}\n"
str += "LANG env: #{ENV["LANG"]}\n" if ENV["LANG"] && !ENV["LANG"].empty?
diff --git a/lib/irb/completion.rb b/lib/irb/completion.rb
index 61bdc33587..e3ebe4abff 100644
--- a/lib/irb/completion.rb
+++ b/lib/irb/completion.rb
@@ -9,6 +9,30 @@ require_relative 'ruby-lex'
module IRB
class BaseCompletor # :nodoc:
+
+ # Set of reserved words used by Ruby, you should not use these for
+ # constants or variables
+ ReservedWords = %w[
+ __ENCODING__ __LINE__ __FILE__
+ BEGIN END
+ alias and
+ begin break
+ case class
+ def defined? do
+ else elsif end ensure
+ false for
+ if in
+ module
+ next nil not
+ or
+ redo rescue retry return
+ self super
+ then true
+ undef unless until
+ when while
+ yield
+ ]
+
def completion_candidates(preposing, target, postposing, bind:)
raise NotImplementedError
end
@@ -94,28 +118,9 @@ module IRB
end
}
- # Set of reserved words used by Ruby, you should not use these for
- # constants or variables
- ReservedWords = %w[
- __ENCODING__ __LINE__ __FILE__
- BEGIN END
- alias and
- begin break
- case class
- def defined? do
- else elsif end ensure
- false for
- if in
- module
- next nil not
- or
- redo rescue retry return
- self super
- then true
- undef unless until
- when while
- yield
- ]
+ def inspect
+ 'RegexpCompletor'
+ end
def complete_require_path(target, preposing, postposing)
if target =~ /\A(['"])([^'"]+)\Z/
diff --git a/lib/irb/context.rb b/lib/irb/context.rb
index a20510d73c..5dfe9d0d71 100644
--- a/lib/irb/context.rb
+++ b/lib/irb/context.rb
@@ -86,14 +86,14 @@ module IRB
when nil
if STDIN.tty? && IRB.conf[:PROMPT_MODE] != :INF_RUBY && !use_singleline?
# Both of multiline mode and singleline mode aren't specified.
- @io = RelineInputMethod.new
+ @io = RelineInputMethod.new(build_completor)
else
@io = nil
end
when false
@io = nil
when true
- @io = RelineInputMethod.new
+ @io = RelineInputMethod.new(build_completor)
end
unless @io
case use_singleline?
@@ -149,6 +149,43 @@ module IRB
@command_aliases = IRB.conf[:COMMAND_ALIASES]
end
+ private def build_completor
+ completor_type = IRB.conf[:COMPLETOR]
+ case completor_type
+ when :regexp
+ return RegexpCompletor.new
+ when :type
+ completor = build_type_completor
+ return completor if completor
+ else
+ warn "Invalid value for IRB.conf[:COMPLETOR]: #{completor_type}"
+ end
+ # Fallback to RegexpCompletor
+ RegexpCompletor.new
+ end
+
+ TYPE_COMPLETION_REQUIRED_PRISM_VERSION = '0.17.1'
+
+ private def build_type_completor
+ unless Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.0.0') && RUBY_ENGINE != 'truffleruby'
+ warn 'TypeCompletion requires RUBY_VERSION >= 3.0.0'
+ return
+ end
+ begin
+ require 'prism'
+ rescue LoadError => e
+ warn "TypeCompletion requires Prism: #{e.message}"
+ return
+ end
+ unless Gem::Version.new(Prism::VERSION) >= Gem::Version.new(TYPE_COMPLETION_REQUIRED_PRISM_VERSION)
+ warn "TypeCompletion requires Prism::VERSION >= #{TYPE_COMPLETION_REQUIRED_PRISM_VERSION}"
+ return
+ end
+ require 'irb/type_completion/completor'
+ TypeCompletion::Types.preload_in_thread
+ TypeCompletion::Completor.new
+ end
+
def save_history=(val)
IRB.conf[:SAVE_HISTORY] = val
end
diff --git a/lib/irb/init.rb b/lib/irb/init.rb
index d9549420b4..e9111974f0 100644
--- a/lib/irb/init.rb
+++ b/lib/irb/init.rb
@@ -76,6 +76,7 @@ module IRB # :nodoc:
@CONF[:USE_SINGLELINE] = false unless defined?(ReadlineInputMethod)
@CONF[:USE_COLORIZE] = (nc = ENV['NO_COLOR']).nil? || nc.empty?
@CONF[:USE_AUTOCOMPLETE] = ENV.fetch("IRB_USE_AUTOCOMPLETE", "true") != "false"
+ @CONF[:COMPLETOR] = :regexp
@CONF[:INSPECT_MODE] = true
@CONF[:USE_TRACER] = false
@CONF[:USE_LOADER] = false
diff --git a/lib/irb/input-method.rb b/lib/irb/input-method.rb
index cef65b7162..94ad28cd63 100644
--- a/lib/irb/input-method.rb
+++ b/lib/irb/input-method.rb
@@ -193,6 +193,10 @@ module IRB
}
end
+ def completion_info
+ 'RegexpCompletor'
+ end
+
# Reads the next line from this input method.
#
# See IO#gets for more information.
@@ -230,13 +234,13 @@ module IRB
HISTORY = Reline::HISTORY
include HistorySavingAbility
# Creates a new input method object using Reline
- def initialize
+ def initialize(completor)
IRB.__send__(:set_encoding, Reline.encoding_system_needs.name, override: false)
- super
+ super()
@eof = false
- @completor = RegexpCompletor.new
+ @completor = completor
Reline.basic_word_break_characters = BASIC_WORD_BREAK_CHARACTERS
Reline.completion_append_character = nil
@@ -270,6 +274,11 @@ module IRB
end
end
+ def completion_info
+ autocomplete_message = Reline.autocompletion ? 'Autocomplete' : 'Tab Complete'
+ "#{autocomplete_message}, #{@completor.inspect}"
+ end
+
def check_termination(&block)
@check_termination_proc = block
end
diff --git a/lib/irb/type_completion/completor.rb b/lib/irb/type_completion/completor.rb
new file mode 100644
index 0000000000..e893fd8adc
--- /dev/null
+++ b/lib/irb/type_completion/completor.rb
@@ -0,0 +1,235 @@
+# frozen_string_literal: true
+
+require 'prism'
+require 'irb/completion'
+require_relative 'type_analyzer'
+
+module IRB
+ module TypeCompletion
+ class Completor < BaseCompletor # :nodoc:
+ HIDDEN_METHODS = %w[Namespace TypeName] # defined by rbs, should be hidden
+
+ class << self
+ attr_accessor :last_completion_error
+ end
+
+ def inspect
+ name = 'TypeCompletion::Completor'
+ prism_info = "Prism: #{Prism::VERSION}"
+ if Types.rbs_builder
+ "#{name}(#{prism_info}, RBS: #{RBS::VERSION})"
+ elsif Types.rbs_load_error
+ "#{name}(#{prism_info}, RBS: #{Types.rbs_load_error.inspect})"
+ else
+ "#{name}(#{prism_info}, RBS: loading)"
+ end
+ end
+
+ def completion_candidates(preposing, target, _postposing, bind:)
+ @preposing = preposing
+ verbose, $VERBOSE = $VERBOSE, nil
+ code = "#{preposing}#{target}"
+ @result = analyze code, bind
+ name, candidates = candidates_from_result(@result)
+
+ all_symbols_pattern = /\A[ -\/:-@\[-`\{-~]*\z/
+ candidates.map(&:to_s).select { !_1.match?(all_symbols_pattern) && _1.start_with?(name) }.uniq.sort.map do
+ target + _1[name.size..]
+ end
+ rescue SyntaxError, StandardError => e
+ Completor.last_completion_error = e
+ handle_error(e)
+ []
+ ensure
+ $VERBOSE = verbose
+ end
+
+ def doc_namespace(preposing, matched, postposing, bind:)
+ name = matched[/[a-zA-Z_0-9]*[!?=]?\z/]
+ method_doc = -> type do
+ type = type.types.find { _1.all_methods.include? name.to_sym }
+ case type
+ when Types::SingletonType
+ "#{Types.class_name_of(type.module_or_class)}.#{name}"
+ when Types::InstanceType
+ "#{Types.class_name_of(type.klass)}##{name}"
+ end
+ end
+ call_or_const_doc = -> type do
+ if name =~ /\A[A-Z]/
+ type = type.types.grep(Types::SingletonType).find { _1.module_or_class.const_defined?(name) }
+ type.module_or_class == Object ? name : "#{Types.class_name_of(type.module_or_class)}::#{name}" if type
+ else
+ method_doc.call(type)
+ end
+ end
+
+ value_doc = -> type do
+ return unless type
+ type.types.each do |t|
+ case t
+ when Types::SingletonType
+ return Types.class_name_of(t.module_or_class)
+ when Types::InstanceType
+ return Types.class_name_of(t.klass)
+ end
+ end
+ nil
+ end
+
+ case @result
+ in [:call_or_const, type, _name, _self_call]
+ call_or_const_doc.call type
+ in [:const, type, _name, scope]
+ if type
+ call_or_const_doc.call type
+ else
+ value_doc.call scope[name]
+ end
+ in [:gvar, _name, scope]
+ value_doc.call scope["$#{name}"]
+ in [:ivar, _name, scope]
+ value_doc.call scope["@#{name}"]
+ in [:cvar, _name, scope]
+ value_doc.call scope["@@#{name}"]
+ in [:call, type, _name, _self_call]
+ method_doc.call type
+ in [:lvar_or_method, _name, scope]
+ if scope.local_variables.include?(name)
+ value_doc.call scope[name]
+ else
+ method_doc.call scope.self_type
+ end
+ else
+ end
+ end
+
+ def candidates_from_result(result)
+ candidates = case result
+ in [:require, name]
+ retrieve_files_to_require_from_load_path
+ in [:require_relative, name]
+ retrieve_files_to_require_relative_from_current_dir
+ in [:call_or_const, type, name, self_call]
+ ((self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS) | type.constants
+ in [:const, type, name, scope]
+ if type
+ scope_constants = type.types.flat_map do |t|
+ scope.table_module_constants(t.module_or_class) if t.is_a?(Types::SingletonType)
+ end
+ (scope_constants.compact | type.constants.map(&:to_s)).sort
+ else
+ scope.constants.sort | ReservedWords
+ end
+ in [:ivar, name, scope]
+ ivars = scope.instance_variables.sort
+ name == '@' ? ivars + scope.class_variables.sort : ivars
+ in [:cvar, name, scope]
+ scope.class_variables
+ in [:gvar, name, scope]
+ scope.global_variables
+ in [:symbol, name]
+ Symbol.all_symbols.map { _1.inspect[1..] }
+ in [:call, type, name, self_call]
+ (self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS
+ in [:lvar_or_method, name, scope]
+ scope.self_type.all_methods.map(&:to_s) | scope.local_variables | ReservedWords
+ else
+ []
+ end
+ [name || '', candidates]
+ end
+
+ def analyze(code, binding = Object::TOPLEVEL_BINDING)
+ # Workaround for https://github.com/ruby/prism/issues/1592
+ return if code.match?(/%[qQ]\z/)
+
+ ast = Prism.parse(code, scopes: [binding.local_variables]).value
+ name = code[/(@@|@|\$)?\w*[!?=]?\z/]
+ *parents, target_node = find_target ast, code.bytesize - name.bytesize
+ return unless target_node
+
+ calculate_scope = -> { TypeAnalyzer.calculate_target_type_scope(binding, parents, target_node).last }
+ calculate_type_scope = ->(node) { TypeAnalyzer.calculate_target_type_scope binding, [*parents, target_node], node }
+
+ case target_node
+ when Prism::StringNode, Prism::InterpolatedStringNode
+ call_node, args_node = parents.last(2)
+ return unless call_node.is_a?(Prism::CallNode) && call_node.receiver.nil?
+ return unless args_node.is_a?(Prism::ArgumentsNode) && args_node.arguments.size == 1
+
+ case call_node.name
+ when :require
+ [:require, name.rstrip]
+ when :require_relative
+ [:require_relative, name.rstrip]
+ end
+ when Prism::SymbolNode
+ if parents.last.is_a? Prism::BlockArgumentNode # method(&:target)
+ receiver_type, _scope = calculate_type_scope.call target_node
+ [:call, receiver_type, name, false]
+ else
+ [:symbol, name] unless name.empty?
+ end
+ when Prism::CallNode
+ return [:lvar_or_method, name, calculate_scope.call] if target_node.receiver.nil?
+
+ self_call = target_node.receiver.is_a? Prism::SelfNode
+ op = target_node.call_operator
+ receiver_type, _scope = calculate_type_scope.call target_node.receiver
+ receiver_type = receiver_type.nonnillable if op == '&.'
+ [op == '::' ? :call_or_const : :call, receiver_type, name, self_call]
+ when Prism::LocalVariableReadNode, Prism::LocalVariableTargetNode
+ [:lvar_or_method, name, calculate_scope.call]
+ when Prism::ConstantReadNode, Prism::ConstantTargetNode
+ if parents.last.is_a? Prism::ConstantPathNode
+ path_node = parents.last
+ if path_node.parent # A::B
+ receiver, scope = calculate_type_scope.call(path_node.parent)
+ [:const, receiver, name, scope]
+ else # ::A
+ scope = calculate_scope.call
+ [:const, Types::SingletonType.new(Object), name, scope]
+ end
+ else
+ [:const, nil, name, calculate_scope.call]
+ end
+ when Prism::GlobalVariableReadNode, Prism::GlobalVariableTargetNode
+ [:gvar, name, calculate_scope.call]
+ when Prism::InstanceVariableReadNode, Prism::InstanceVariableTargetNode
+ [:ivar, name, calculate_scope.call]
+ when Prism::ClassVariableReadNode, Prism::ClassVariableTargetNode
+ [:cvar, name, calculate_scope.call]
+ end
+ end
+
+ def find_target(node, position)
+ location = (
+ case node
+ when Prism::CallNode
+ node.message_loc
+ when Prism::SymbolNode
+ node.value_loc
+ when Prism::StringNode
+ node.content_loc
+ when Prism::InterpolatedStringNode
+ node.closing_loc if node.parts.empty?
+ end
+ )
+ return [node] if location&.start_offset == position
+
+ node.compact_child_nodes.each do |n|
+ match = find_target(n, position)
+ next unless match
+ match.unshift node
+ return match
+ end
+
+ [node] if node.location.start_offset == position
+ end
+
+ def handle_error(e)
+ end
+ end
+ end
+end
diff --git a/lib/irb/type_completion/methods.rb b/lib/irb/type_completion/methods.rb
new file mode 100644
index 0000000000..8a88b6d0f9
--- /dev/null
+++ b/lib/irb/type_completion/methods.rb
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+module IRB
+ module TypeCompletion
+ module Methods
+ OBJECT_SINGLETON_CLASS_METHOD = Object.instance_method(:singleton_class)
+ OBJECT_INSTANCE_VARIABLES_METHOD = Object.instance_method(:instance_variables)
+ OBJECT_INSTANCE_VARIABLE_GET_METHOD = Object.instance_method(:instance_variable_get)
+ OBJECT_CLASS_METHOD = Object.instance_method(:class)
+ MODULE_NAME_METHOD = Module.instance_method(:name)
+ end
+ end
+end
diff --git a/lib/irb/type_completion/scope.rb b/lib/irb/type_completion/scope.rb
new file mode 100644
index 0000000000..5a58a0ed65
--- /dev/null
+++ b/lib/irb/type_completion/scope.rb
@@ -0,0 +1,412 @@
+# frozen_string_literal: true
+
+require 'set'
+require_relative 'types'
+
+module IRB
+ module TypeCompletion
+
+ class RootScope
+ attr_reader :module_nesting, :self_object
+
+ def initialize(binding, self_object, local_variables)
+ @binding = binding
+ @self_object = self_object
+ @cache = {}
+ modules = [*binding.eval('::Module.nesting'), Object]
+ @module_nesting = modules.map { [_1, []] }
+ binding_local_variables = binding.local_variables
+ uninitialized_locals = local_variables - binding_local_variables
+ uninitialized_locals.each { @cache[_1] = Types::NIL }
+ @local_variables = (local_variables | binding_local_variables).map(&:to_s).to_set
+ @global_variables = Kernel.global_variables.map(&:to_s).to_set
+ @owned_constants_cache = {}
+ end
+
+ def level() = 0
+
+ def level_of(_name, _var_type) = 0
+
+ def mutable?() = false
+
+ def module_own_constant?(mod, name)
+ set = (@owned_constants_cache[mod] ||= Set.new(mod.constants.map(&:to_s)))
+ set.include? name
+ end
+
+ def get_const(nesting, path, _key = nil)
+ return unless nesting
+
+ result = path.reduce nesting do |mod, name|
+ return nil unless mod.is_a?(Module) && module_own_constant?(mod, name)
+ mod.const_get name
+ end
+ Types.type_from_object result
+ end
+
+ def get_cvar(nesting, path, name, _key = nil)
+ return Types::NIL unless nesting
+
+ result = path.reduce nesting do |mod, n|
+ return Types::NIL unless mod.is_a?(Module) && module_own_constant?(mod, n)
+ mod.const_get n
+ end
+ value = result.class_variable_get name if result.is_a?(Module) && name.size >= 3 && result.class_variable_defined?(name)
+ Types.type_from_object value
+ end
+
+ def [](name)
+ @cache[name] ||= (
+ value = case RootScope.type_by_name name
+ when :ivar
+ begin
+ Methods::OBJECT_INSTANCE_VARIABLE_GET_METHOD.bind_call(@self_object, name)
+ rescue NameError
+ end
+ when :lvar
+ begin
+ @binding.local_variable_get(name)
+ rescue NameError
+ end
+ when :gvar
+ @binding.eval name if @global_variables.include? name
+ end
+ Types.type_from_object(value)
+ )
+ end
+
+ def self_type
+ Types.type_from_object @self_object
+ end
+
+ def local_variables() = @local_variables.to_a
+
+ def global_variables() = @global_variables.to_a
+
+ def self.type_by_name(name)
+ if name.start_with? '@@'
+ # "@@cvar" or "@@cvar::[module_id]::[module_path]"
+ :cvar
+ elsif name.start_with? '@'
+ :ivar
+ elsif name.start_with? '$'
+ :gvar
+ elsif name.start_with? '%'
+ :internal
+ elsif name[0].downcase != name[0] || name[0].match?(/\d/)
+ # "ConstName" or "[module_id]::[const_path]"
+ :const
+ else
+ :lvar
+ end
+ end
+ end
+
+ class Scope
+ BREAK_RESULT = '%break'
+ NEXT_RESULT = '%next'
+ RETURN_RESULT = '%return'
+ PATTERNMATCH_BREAK = '%match'
+
+ attr_reader :parent, :mergeable_changes, :level, :module_nesting
+
+ def self.from_binding(binding, locals) = new(RootScope.new(binding, binding.receiver, locals))
+
+ def initialize(parent, table = {}, trace_ivar: true, trace_lvar: true, self_type: nil, nesting: nil)
+ @parent = parent
+ @level = parent.level + 1
+ @trace_ivar = trace_ivar
+ @trace_lvar = trace_lvar
+ @module_nesting = nesting ? [nesting, *parent.module_nesting] : parent.module_nesting
+ @self_type = self_type
+ @terminated = false
+ @jump_branches = []
+ @mergeable_changes = @table = table.transform_values { [level, _1] }
+ end
+
+ def mutable? = true
+
+ def terminated?
+ @terminated
+ end
+
+ def terminate_with(type, value)
+ return if terminated?
+ store_jump type, value, @mergeable_changes
+ terminate
+ end
+
+ def store_jump(type, value, changes)
+ return if terminated?
+ if has_own?(type)
+ changes[type] = [level, value]
+ @jump_branches << changes
+ elsif @parent.mutable?
+ @parent.store_jump(type, value, changes)
+ end
+ end
+
+ def terminate
+ return if terminated?
+ @terminated = true
+ @table = @mergeable_changes.dup
+ end
+
+ def trace?(name)
+ return false unless @parent
+ type = RootScope.type_by_name(name)
+ type == :ivar ? @trace_ivar : type == :lvar ? @trace_lvar : true
+ end
+
+ def level_of(name, var_type)
+ case var_type
+ when :ivar
+ return level unless @trace_ivar
+ when :gvar
+ return 0
+ end
+ variable_level, = @table[name]
+ variable_level || parent.level_of(name, var_type)
+ end
+
+ def get_const(nesting, path, key = nil)
+ key ||= [nesting.__id__, path].join('::')
+ _l, value = @table[key]
+ value || @parent.get_const(nesting, path, key)
+ end
+
+ def get_cvar(nesting, path, name, key = nil)
+ key ||= [name, nesting.__id__, path].join('::')
+ _l, value = @table[key]
+ value || @parent.get_cvar(nesting, path, name, key)
+ end
+
+ def [](name)
+ type = RootScope.type_by_name(name)
+ if type == :const
+ return get_const(nil, nil, name) || Types::NIL if name.include?('::')
+
+ module_nesting.each do |(nesting, path)|
+ value = get_const nesting, [*path, name]
+ return value if value
+ end
+ return Types::NIL
+ elsif type == :cvar
+ return get_cvar(nil, nil, nil, name) if name.include?('::')
+
+ nesting, path = module_nesting.first
+ return get_cvar(nesting, path, name)
+ end
+ level, value = @table[name]
+ if level
+ value
+ elsif trace? name
+ @parent[name]
+ elsif type == :ivar
+ self_instance_variable_get name
+ end
+ end
+
+ def set_const(nesting, path, value)
+ key = [nesting.__id__, path].join('::')
+ @table[key] = [0, value]
+ end
+
+ def set_cvar(nesting, path, name, value)
+ key = [name, nesting.__id__, path].join('::')
+ @table[key] = [0, value]
+ end
+
+ def []=(name, value)
+ type = RootScope.type_by_name(name)
+ if type == :const
+ if name.include?('::')
+ @table[name] = [0, value]
+ else
+ parent_module, parent_path = module_nesting.first
+ set_const parent_module, [*parent_path, name], value
+ end
+ return
+ elsif type == :cvar
+ if name.include?('::')
+ @table[name] = [0, value]
+ else
+ parent_module, parent_path = module_nesting.first
+ set_cvar parent_module, parent_path, name, value
+ end
+ return
+ end
+ variable_level = level_of name, type
+ @table[name] = [variable_level, value] if variable_level
+ end
+
+ def self_type
+ @self_type || @parent.self_type
+ end
+
+ def global_variables
+ gvar_keys = @table.keys.select do |name|
+ RootScope.type_by_name(name) == :gvar
+ end
+ gvar_keys | @parent.global_variables
+ end
+
+ def local_variables
+ lvar_keys = @table.keys.select do |name|
+ RootScope.type_by_name(name) == :lvar
+ end
+ lvar_keys |= @parent.local_variables if @trace_lvar
+ lvar_keys
+ end
+
+ def table_constants
+ constants = module_nesting.flat_map do |mod, path|
+ prefix = [mod.__id__, *path].join('::') + '::'
+ @table.keys.select { _1.start_with? prefix }.map { _1.delete_prefix(prefix).split('::').first }
+ end.uniq
+ constants |= @parent.table_constants if @parent.mutable?
+ constants
+ end
+
+ def table_module_constants(mod)
+ prefix = "#{mod.__id__}::"
+ constants = @table.keys.select { _1.start_with? prefix }.map { _1.delete_prefix(prefix).split('::').first }
+ constants |= @parent.table_constants if @parent.mutable?
+ constants
+ end
+
+ def base_scope
+ @parent.mutable? ? @parent.base_scope : @parent
+ end
+
+ def table_instance_variables
+ ivars = @table.keys.select { RootScope.type_by_name(_1) == :ivar }
+ ivars |= @parent.table_instance_variables if @parent.mutable? && @trace_ivar
+ ivars
+ end
+
+ def instance_variables
+ self_singleton_types = self_type.types.grep(Types::SingletonType)
+ singleton_classes = self_type.types.grep(Types::InstanceType).map(&:klass).select(&:singleton_class?)
+ base_self = base_scope.self_object
+ self_instance_variables = singleton_classes.flat_map do |singleton_class|
+ if singleton_class.respond_to? :attached_object
+ Methods::OBJECT_INSTANCE_VARIABLES_METHOD.bind_call(singleton_class.attached_object).map(&:to_s)
+ elsif singleton_class == Methods::OBJECT_SINGLETON_CLASS_METHOD.bind_call(base_self)
+ Methods::OBJECT_INSTANCE_VARIABLES_METHOD.bind_call(base_self).map(&:to_s)
+ else
+ []
+ end
+ end
+ [
+ self_singleton_types.flat_map { _1.module_or_class.instance_variables.map(&:to_s) },
+ self_instance_variables || [],
+ table_instance_variables
+ ].inject(:|)
+ end
+
+ def self_instance_variable_get(name)
+ self_objects = self_type.types.grep(Types::SingletonType).map(&:module_or_class)
+ singleton_classes = self_type.types.grep(Types::InstanceType).map(&:klass).select(&:singleton_class?)
+ base_self = base_scope.self_object
+ singleton_classes.each do |singleton_class|
+ if singleton_class.respond_to? :attached_object
+ self_objects << singleton_class.attached_object
+ elsif singleton_class == base_self.singleton_class
+ self_objects << base_self
+ end
+ end
+ types = self_objects.map do |object|
+ value = begin
+ Methods::OBJECT_INSTANCE_VARIABLE_GET_METHOD.bind_call(object, name)
+ rescue NameError
+ end
+ Types.type_from_object value
+ end
+ Types::UnionType[*types]
+ end
+
+ def table_class_variables
+ cvars = @table.keys.filter_map { _1.split('::', 2).first if RootScope.type_by_name(_1) == :cvar }
+ cvars |= @parent.table_class_variables if @parent.mutable?
+ cvars
+ end
+
+ def class_variables
+ cvars = table_class_variables
+ m, = module_nesting.first
+ cvars |= m.class_variables.map(&:to_s) if m.is_a? Module
+ cvars
+ end
+
+ def constants
+ module_nesting.flat_map do |nest,|
+ nest.constants
+ end.map(&:to_s) | table_constants
+ end
+
+ def merge_jumps
+ if terminated?
+ @terminated = false
+ @table = @mergeable_changes
+ merge @jump_branches
+ @terminated = true
+ else
+ merge [*@jump_branches, {}]
+ end
+ end
+
+ def conditional(&block)
+ run_branches(block, ->(_s) {}).first || Types::NIL
+ end
+
+ def never(&block)
+ block.call Scope.new(self, { BREAK_RESULT => nil, NEXT_RESULT => nil, PATTERNMATCH_BREAK => nil, RETURN_RESULT => nil })
+ end
+
+ def run_branches(*blocks)
+ results = []
+ branches = []
+ blocks.each do |block|
+ scope = Scope.new self
+ result = block.call scope
+ next if scope.terminated?
+ results << result
+ branches << scope.mergeable_changes
+ end
+ terminate if branches.empty?
+ merge branches
+ results
+ end
+
+ def has_own?(name)
+ @table.key? name
+ end
+
+ def update(child_scope)
+ current_level = level
+ child_scope.mergeable_changes.each do |name, (level, value)|
+ self[name] = value if level <= current_level
+ end
+ end
+
+ protected
+
+ def merge(branches)
+ current_level = level
+ merge = {}
+ branches.each do |changes|
+ changes.each do |name, (level, value)|
+ next if current_level < level
+ (merge[name] ||= []) << value
+ end
+ end
+ merge.each do |name, values|
+ values << self[name] unless values.size == branches.size
+ values.compact!
+ self[name] = Types::UnionType[*values.compact] unless values.empty?
+ end
+ end
+ end
+ end
+end
diff --git a/lib/irb/type_completion/type_analyzer.rb b/lib/irb/type_completion/type_analyzer.rb
new file mode 100644
index 0000000000..c4a41e4999
--- /dev/null
+++ b/lib/irb/type_completion/type_analyzer.rb
@@ -0,0 +1,1169 @@
+# frozen_string_literal: true
+
+require 'set'
+require_relative 'types'
+require_relative 'scope'
+require 'prism'
+
+module IRB
+ module TypeCompletion
+ class TypeAnalyzer
+ class DigTarget
+ def initialize(parents, receiver, &block)
+ @dig_ids = parents.to_h { [_1.__id__, true] }
+ @target_id = receiver.__id__
+ @block = block
+ end
+
+ def dig?(node) = @dig_ids[node.__id__]
+ def target?(node) = @target_id == node.__id__
+ def resolve(type, scope)
+ @block.call type, scope
+ end
+ end
+
+ OBJECT_METHODS = {
+ to_s: Types::STRING,
+ to_str: Types::STRING,
+ to_a: Types::ARRAY,
+ to_ary: Types::ARRAY,
+ to_h: Types::HASH,
+ to_hash: Types::HASH,
+ to_i: Types::INTEGER,
+ to_int: Types::INTEGER,
+ to_f: Types::FLOAT,
+ to_c: Types::COMPLEX,
+ to_r: Types::RATIONAL
+ }
+
+ def initialize(dig_targets)
+ @dig_targets = dig_targets
+ end
+
+ def evaluate(node, scope)
+ method = "evaluate_#{node.type}"
+ if respond_to? method
+ result = send method, node, scope
+ else
+ result = Types::NIL
+ end
+ @dig_targets.resolve result, scope if @dig_targets.target? node
+ result
+ end
+
+ def evaluate_program_node(node, scope)
+ evaluate node.statements, scope
+ end
+
+ def evaluate_statements_node(node, scope)
+ if node.body.empty?
+ Types::NIL
+ else
+ node.body.map { evaluate _1, scope }.last
+ end
+ end
+
+ def evaluate_def_node(node, scope)
+ if node.receiver
+ self_type = evaluate node.receiver, scope
+ else
+ current_self_types = scope.self_type.types
+ self_types = current_self_types.map do |type|
+ if type.is_a?(Types::SingletonType) && type.module_or_class.is_a?(Class)
+ Types::InstanceType.new type.module_or_class
+ else
+ type
+ end
+ end
+ self_type = Types::UnionType[*self_types]
+ end
+ if @dig_targets.dig?(node.body) || @dig_targets.dig?(node.parameters)
+ params_table = node.locals.to_h { [_1.to_s, Types::NIL] }
+ method_scope = Scope.new(
+ scope,
+ { **params_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil },
+ self_type: self_type,
+ trace_lvar: false,
+ trace_ivar: false
+ )
+ if node.parameters
+ # node.parameters is Prism::ParametersNode
+ assign_parameters node.parameters, method_scope, [], {}
+ end
+
+ if @dig_targets.dig?(node.body)
+ method_scope.conditional do |s|
+ evaluate node.body, s
+ end
+ end
+ method_scope.merge_jumps
+ scope.update method_scope
+ end
+ Types::SYMBOL
+ end
+
+ def evaluate_integer_node(_node, _scope) = Types::INTEGER
+
+ def evaluate_float_node(_node, _scope) = Types::FLOAT
+
+ def evaluate_rational_node(_node, _scope) = Types::RATIONAL
+
+ def evaluate_imaginary_node(_node, _scope) = Types::COMPLEX
+
+ def evaluate_string_node(_node, _scope) = Types::STRING
+
+ def evaluate_x_string_node(_node, _scope)
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_symbol_node(_node, _scope) = Types::SYMBOL
+
+ def evaluate_regular_expression_node(_node, _scope) = Types::REGEXP
+
+ def evaluate_string_concat_node(node, scope)
+ evaluate node.left, scope
+ evaluate node.right, scope
+ Types::STRING
+ end
+
+ def evaluate_interpolated_string_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::STRING
+ end
+
+ def evaluate_interpolated_x_string_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::STRING
+ end
+
+ def evaluate_interpolated_symbol_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::SYMBOL
+ end
+
+ def evaluate_interpolated_regular_expression_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::REGEXP
+ end
+
+ def evaluate_embedded_statements_node(node, scope)
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ Types::STRING
+ end
+
+ def evaluate_embedded_variable_node(node, scope)
+ evaluate node.variable, scope
+ Types::STRING
+ end
+
+ def evaluate_array_node(node, scope)
+ Types.array_of evaluate_list_splat_items(node.elements, scope)
+ end
+
+ def evaluate_hash_node(node, scope) = evaluate_hash(node, scope)
+ def evaluate_keyword_hash_node(node, scope) = evaluate_hash(node, scope)
+ def evaluate_hash(node, scope)
+ keys = []
+ values = []
+ node.elements.each do |assoc|
+ case assoc
+ when Prism::AssocNode
+ keys << evaluate(assoc.key, scope)
+ values << evaluate(assoc.value, scope)
+ when Prism::AssocSplatNode
+ next unless assoc.value # def f(**); {**}
+
+ hash = evaluate assoc.value, scope
+ unless hash.is_a?(Types::InstanceType) && hash.klass == Hash
+ hash = method_call hash, :to_hash, [], nil, nil, scope
+ end
+ if hash.is_a?(Types::InstanceType) && hash.klass == Hash
+ keys << hash.params[:K] if hash.params[:K]
+ values << hash.params[:V] if hash.params[:V]
+ end
+ end
+ end
+ if keys.empty? && values.empty?
+ Types::InstanceType.new Hash
+ else
+ Types::InstanceType.new Hash, K: Types::UnionType[*keys], V: Types::UnionType[*values]
+ end
+ end
+
+ def evaluate_parentheses_node(node, scope)
+ node.body ? evaluate(node.body, scope) : Types::NIL
+ end
+
+ def evaluate_constant_path_node(node, scope)
+ type, = evaluate_constant_node_info node, scope
+ type
+ end
+
+ def evaluate_self_node(_node, scope) = scope.self_type
+
+ def evaluate_true_node(_node, _scope) = Types::TRUE
+
+ def evaluate_false_node(_node, _scope) = Types::FALSE
+
+ def evaluate_nil_node(_node, _scope) = Types::NIL
+
+ def evaluate_source_file_node(_node, _scope) = Types::STRING
+
+ def evaluate_source_line_node(_node, _scope) = Types::INTEGER
+
+ def evaluate_source_encoding_node(_node, _scope) = Types::InstanceType.new(Encoding)
+
+ def evaluate_numbered_reference_read_node(_node, _scope)
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_back_reference_read_node(_node, _scope)
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_reference_read(node, scope)
+ scope[node.name.to_s] || Types::NIL
+ end
+ alias evaluate_constant_read_node evaluate_reference_read
+ alias evaluate_global_variable_read_node evaluate_reference_read
+ alias evaluate_local_variable_read_node evaluate_reference_read
+ alias evaluate_class_variable_read_node evaluate_reference_read
+ alias evaluate_instance_variable_read_node evaluate_reference_read
+
+
+ def evaluate_call_node(node, scope)
+ is_field_assign = node.name.match?(/[^<>=!\]]=\z/) || (node.name == :[]= && !node.call_operator)
+ receiver_type = node.receiver ? evaluate(node.receiver, scope) : scope.self_type
+ evaluate_method = lambda do |scope|
+ args_types, kwargs_types, block_sym_node, has_block = evaluate_call_node_arguments node, scope
+
+ if block_sym_node
+ block_sym = block_sym_node.value
+ if @dig_targets.target? block_sym_node
+ # method(args, &:completion_target)
+ call_block_proc = ->(block_args, _self_type) do
+ block_receiver = block_args.first || Types::OBJECT
+ @dig_targets.resolve block_receiver, scope
+ Types::OBJECT
+ end
+ else
+ call_block_proc = ->(block_args, _self_type) do
+ block_receiver, *rest = block_args
+ block_receiver ? method_call(block_receiver || Types::OBJECT, block_sym, rest, nil, nil, scope) : Types::OBJECT
+ end
+ end
+ elsif node.block.is_a? Prism::BlockNode
+ call_block_proc = ->(block_args, block_self_type) do
+ scope.conditional do |s|
+ numbered_parameters = node.block.locals.grep(/\A_[1-9]/).map(&:to_s)
+ params_table = node.block.locals.to_h { [_1.to_s, Types::NIL] }
+ table = { **params_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil }
+ block_scope = Scope.new s, table, self_type: block_self_type, trace_ivar: !block_self_type
+ # TODO kwargs
+ if node.block.parameters&.parameters
+ # node.block.parameters is Prism::BlockParametersNode
+ assign_parameters node.block.parameters.parameters, block_scope, block_args, {}
+ elsif !numbered_parameters.empty?
+ assign_numbered_parameters numbered_parameters, block_scope, block_args, {}
+ end
+ result = node.block.body ? evaluate(node.block.body, block_scope) : Types::NIL
+ block_scope.merge_jumps
+ s.update block_scope
+ nexts = block_scope[Scope::NEXT_RESULT]
+ breaks = block_scope[Scope::BREAK_RESULT]
+ if block_scope.terminated?
+ [Types::UnionType[*nexts], breaks]
+ else
+ [Types::UnionType[result, *nexts], breaks]
+ end
+ end
+ end
+ elsif has_block
+ call_block_proc = ->(_block_args, _self_type) { Types::OBJECT }
+ end
+ result = method_call receiver_type, node.name, args_types, kwargs_types, call_block_proc, scope
+ if is_field_assign
+ args_types.last || Types::NIL
+ else
+ result
+ end
+ end
+ if node.call_operator == '&.'
+ result = scope.conditional { evaluate_method.call _1 }
+ if receiver_type.nillable?
+ Types::UnionType[result, Types::NIL]
+ else
+ result
+ end
+ else
+ evaluate_method.call scope
+ end
+ end
+
+ def evaluate_and_node(node, scope) = evaluate_and_or(node, scope, and_op: true)
+ def evaluate_or_node(node, scope) = evaluate_and_or(node, scope, and_op: false)
+ def evaluate_and_or(node, scope, and_op:)
+ left = evaluate node.left, scope
+ right = scope.conditional { evaluate node.right, _1 }
+ if and_op
+ Types::UnionType[right, Types::NIL, Types::FALSE]
+ else
+ Types::UnionType[left, right]
+ end
+ end
+
+ def evaluate_call_operator_write_node(node, scope) = evaluate_call_write(node, scope, :operator, node.write_name)
+ def evaluate_call_and_write_node(node, scope) = evaluate_call_write(node, scope, :and, node.write_name)
+ def evaluate_call_or_write_node(node, scope) = evaluate_call_write(node, scope, :or, node.write_name)
+ def evaluate_index_operator_write_node(node, scope) = evaluate_call_write(node, scope, :operator, :[]=)
+ def evaluate_index_and_write_node(node, scope) = evaluate_call_write(node, scope, :and, :[]=)
+ def evaluate_index_or_write_node(node, scope) = evaluate_call_write(node, scope, :or, :[]=)
+ def evaluate_call_write(node, scope, operator, write_name)
+ receiver_type = evaluate node.receiver, scope
+ if write_name == :[]=
+ args_types, kwargs_types, block_sym_node, has_block = evaluate_call_node_arguments node, scope
+ else
+ args_types = []
+ end
+ if block_sym_node
+ block_sym = block_sym_node.value
+ call_block_proc = ->(block_args, _self_type) do
+ block_receiver, *rest = block_args
+ block_receiver ? method_call(block_receiver || Types::OBJECT, block_sym, rest, nil, nil, scope) : Types::OBJECT
+ end
+ elsif has_block
+ call_block_proc = ->(_block_args, _self_type) { Types::OBJECT }
+ end
+ method = write_name.to_s.delete_suffix('=')
+ left = method_call receiver_type, method, args_types, kwargs_types, call_block_proc, scope
+ case operator
+ when :and
+ right = scope.conditional { evaluate node.value, _1 }
+ Types::UnionType[right, Types::NIL, Types::FALSE]
+ when :or
+ right = scope.conditional { evaluate node.value, _1 }
+ Types::UnionType[left, right]
+ else
+ right = evaluate node.value, scope
+ method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ end
+ end
+
+ def evaluate_variable_operator_write(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = evaluate node.value, scope
+ scope[node.name.to_s] = method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ end
+ alias evaluate_global_variable_operator_write_node evaluate_variable_operator_write
+ alias evaluate_local_variable_operator_write_node evaluate_variable_operator_write
+ alias evaluate_class_variable_operator_write_node evaluate_variable_operator_write
+ alias evaluate_instance_variable_operator_write_node evaluate_variable_operator_write
+
+ def evaluate_variable_and_write(node, scope)
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[right, Types::NIL, Types::FALSE]
+ end
+ alias evaluate_global_variable_and_write_node evaluate_variable_and_write
+ alias evaluate_local_variable_and_write_node evaluate_variable_and_write
+ alias evaluate_class_variable_and_write_node evaluate_variable_and_write
+ alias evaluate_instance_variable_and_write_node evaluate_variable_and_write
+
+ def evaluate_variable_or_write(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[left, right]
+ end
+ alias evaluate_global_variable_or_write_node evaluate_variable_or_write
+ alias evaluate_local_variable_or_write_node evaluate_variable_or_write
+ alias evaluate_class_variable_or_write_node evaluate_variable_or_write
+ alias evaluate_instance_variable_or_write_node evaluate_variable_or_write
+
+ def evaluate_constant_operator_write_node(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = evaluate node.value, scope
+ scope[node.name.to_s] = method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ end
+
+ def evaluate_constant_and_write_node(node, scope)
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[right, Types::NIL, Types::FALSE]
+ end
+
+ def evaluate_constant_or_write_node(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[left, right]
+ end
+
+ def evaluate_constant_path_operator_write_node(node, scope)
+ left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope
+ right = evaluate node.value, scope
+ value = method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ const_path_write receiver, name, value, scope
+ value
+ end
+
+ def evaluate_constant_path_and_write_node(node, scope)
+ _left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope
+ right = scope.conditional { evaluate node.value, scope }
+ value = Types::UnionType[right, Types::NIL, Types::FALSE]
+ const_path_write receiver, name, value, scope
+ value
+ end
+
+ def evaluate_constant_path_or_write_node(node, scope)
+ left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope
+ right = scope.conditional { evaluate node.value, scope }
+ value = Types::UnionType[left, right]
+ const_path_write receiver, name, value, scope
+ value
+ end
+
+ def evaluate_constant_path_write_node(node, scope)
+ receiver = evaluate node.target.parent, scope if node.target.parent
+ value = evaluate node.value, scope
+ const_path_write receiver, node.target.child.name.to_s, value, scope
+ value
+ end
+
+ def evaluate_lambda_node(node, scope)
+ local_table = node.locals.to_h { [_1.to_s, Types::OBJECT] }
+ block_scope = Scope.new scope, { **local_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil }
+ block_scope.conditional do |s|
+ assign_parameters node.parameters.parameters, s, [], {} if node.parameters&.parameters
+ evaluate node.body, s if node.body
+ end
+ block_scope.merge_jumps
+ scope.update block_scope
+ Types::PROC
+ end
+
+ def evaluate_reference_write(node, scope)
+ scope[node.name.to_s] = evaluate node.value, scope
+ end
+ alias evaluate_constant_write_node evaluate_reference_write
+ alias evaluate_global_variable_write_node evaluate_reference_write
+ alias evaluate_local_variable_write_node evaluate_reference_write
+ alias evaluate_class_variable_write_node evaluate_reference_write
+ alias evaluate_instance_variable_write_node evaluate_reference_write
+
+ def evaluate_multi_write_node(node, scope)
+ evaluated_receivers = {}
+ evaluate_multi_write_receiver node, scope, evaluated_receivers
+ value = (
+ if node.value.is_a? Prism::ArrayNode
+ if node.value.elements.any?(Prism::SplatNode)
+ evaluate node.value, scope
+ else
+ node.value.elements.map do |n|
+ evaluate n, scope
+ end
+ end
+ elsif node.value
+ evaluate node.value, scope
+ else
+ Types::NIL
+ end
+ )
+ evaluate_multi_write node, value, scope, evaluated_receivers
+ value.is_a?(Array) ? Types.array_of(*value) : value
+ end
+
+ def evaluate_if_node(node, scope) = evaluate_if_unless(node, scope)
+ def evaluate_unless_node(node, scope) = evaluate_if_unless(node, scope)
+ def evaluate_if_unless(node, scope)
+ evaluate node.predicate, scope
+ Types::UnionType[*scope.run_branches(
+ -> { node.statements ? evaluate(node.statements, _1) : Types::NIL },
+ -> { node.consequent ? evaluate(node.consequent, _1) : Types::NIL }
+ )]
+ end
+
+ def evaluate_else_node(node, scope)
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ end
+
+ def evaluate_while_until(node, scope)
+ inner_scope = Scope.new scope, { Scope::BREAK_RESULT => nil }
+ evaluate node.predicate, inner_scope
+ if node.statements
+ inner_scope.conditional do |s|
+ evaluate node.statements, s
+ end
+ end
+ inner_scope.merge_jumps
+ scope.update inner_scope
+ breaks = inner_scope[Scope::BREAK_RESULT]
+ breaks ? Types::UnionType[breaks, Types::NIL] : Types::NIL
+ end
+ alias evaluate_while_node evaluate_while_until
+ alias evaluate_until_node evaluate_while_until
+
+ def evaluate_break_node(node, scope) = evaluate_jump(node, scope, :break)
+ def evaluate_next_node(node, scope) = evaluate_jump(node, scope, :next)
+ def evaluate_return_node(node, scope) = evaluate_jump(node, scope, :return)
+ def evaluate_jump(node, scope, mode)
+ internal_key = (
+ case mode
+ when :break
+ Scope::BREAK_RESULT
+ when :next
+ Scope::NEXT_RESULT
+ when :return
+ Scope::RETURN_RESULT
+ end
+ )
+ jump_value = (
+ arguments = node.arguments&.arguments
+ if arguments.nil? || arguments.empty?
+ Types::NIL
+ elsif arguments.size == 1 && !arguments.first.is_a?(Prism::SplatNode)
+ evaluate arguments.first, scope
+ else
+ Types.array_of evaluate_list_splat_items(arguments, scope)
+ end
+ )
+ scope.terminate_with internal_key, jump_value
+ Types::NIL
+ end
+
+ def evaluate_yield_node(node, scope)
+ evaluate_list_splat_items node.arguments.arguments, scope if node.arguments
+ Types::OBJECT
+ end
+
+ def evaluate_redo_node(_node, scope)
+ scope.terminate
+ Types::NIL
+ end
+
+ def evaluate_retry_node(_node, scope)
+ scope.terminate
+ Types::NIL
+ end
+
+ def evaluate_forwarding_super_node(_node, _scope) = Types::OBJECT
+
+ def evaluate_super_node(node, scope)
+ evaluate_list_splat_items node.arguments.arguments, scope if node.arguments
+ Types::OBJECT
+ end
+
+ def evaluate_begin_node(node, scope)
+ return_type = node.statements ? evaluate(node.statements, scope) : Types::NIL
+ if node.rescue_clause
+ if node.else_clause
+ return_types = scope.run_branches(
+ ->{ evaluate node.rescue_clause, _1 },
+ ->{ evaluate node.else_clause, _1 }
+ )
+ else
+ return_types = [
+ return_type,
+ scope.conditional { evaluate node.rescue_clause, _1 }
+ ]
+ end
+ return_type = Types::UnionType[*return_types]
+ end
+ if node.ensure_clause&.statements
+ # ensure_clause is Prism::EnsureNode
+ evaluate node.ensure_clause.statements, scope
+ end
+ return_type
+ end
+
+ def evaluate_rescue_node(node, scope)
+ run_rescue = lambda do |s|
+ if node.reference
+ error_classes_type = evaluate_list_splat_items node.exceptions, s
+ error_types = error_classes_type.types.filter_map do
+ Types::InstanceType.new _1.module_or_class if _1.is_a?(Types::SingletonType)
+ end
+ error_types << Types::InstanceType.new(StandardError) if error_types.empty?
+ error_type = Types::UnionType[*error_types]
+ case node.reference
+ when Prism::LocalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::ConstantTargetNode
+ s[node.reference.name.to_s] = error_type
+ when Prism::CallNode
+ evaluate node.reference, s
+ end
+ end
+ node.statements ? evaluate(node.statements, s) : Types::NIL
+ end
+ if node.consequent # begin; rescue A; rescue B; end
+ types = scope.run_branches(
+ run_rescue,
+ -> { evaluate node.consequent, _1 }
+ )
+ Types::UnionType[*types]
+ else
+ run_rescue.call scope
+ end
+ end
+
+ def evaluate_rescue_modifier_node(node, scope)
+ a = evaluate node.expression, scope
+ b = scope.conditional { evaluate node.rescue_expression, _1 }
+ Types::UnionType[a, b]
+ end
+
+ def evaluate_singleton_class_node(node, scope)
+ klass_types = evaluate(node.expression, scope).types.filter_map do |type|
+ Types::SingletonType.new type.klass if type.is_a? Types::InstanceType
+ end
+ klass_types = [Types::CLASS] if klass_types.empty?
+ table = node.locals.to_h { [_1.to_s, Types::NIL] }
+ sclass_scope = Scope.new(
+ scope,
+ { **table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil },
+ trace_ivar: false,
+ trace_lvar: false,
+ self_type: Types::UnionType[*klass_types]
+ )
+ result = node.body ? evaluate(node.body, sclass_scope) : Types::NIL
+ scope.update sclass_scope
+ result
+ end
+
+ def evaluate_class_node(node, scope) = evaluate_class_module(node, scope, true)
+ def evaluate_module_node(node, scope) = evaluate_class_module(node, scope, false)
+ def evaluate_class_module(node, scope, is_class)
+ unless node.constant_path.is_a?(Prism::ConstantReadNode) || node.constant_path.is_a?(Prism::ConstantPathNode)
+ # Incomplete class/module `class (statement[cursor_here])::Name; end`
+ evaluate node.constant_path, scope
+ return Types::NIL
+ end
+ const_type, _receiver, parent_module, name = evaluate_constant_node_info node.constant_path, scope
+ if is_class
+ select_class_type = -> { _1.is_a?(Types::SingletonType) && _1.module_or_class.is_a?(Class) }
+ module_types = const_type.types.select(&select_class_type)
+ module_types += evaluate(node.superclass, scope).types.select(&select_class_type) if node.superclass
+ module_types << Types::CLASS if module_types.empty?
+ else
+ module_types = const_type.types.select { _1.is_a?(Types::SingletonType) && !_1.module_or_class.is_a?(Class) }
+ module_types << Types::MODULE if module_types.empty?
+ end
+ return Types::NIL unless node.body
+
+ table = node.locals.to_h { [_1.to_s, Types::NIL] }
+ if !name.empty? && (parent_module.is_a?(Module) || parent_module.nil?)
+ value = parent_module.const_get name if parent_module&.const_defined? name
+ unless value
+ value_type = scope[name]
+ value = value_type.module_or_class if value_type.is_a? Types::SingletonType
+ end
+
+ if value.is_a? Module
+ nesting = [value, []]
+ else
+ if parent_module
+ nesting = [parent_module, [name]]
+ else
+ parent_nesting, parent_path = scope.module_nesting.first
+ nesting = [parent_nesting, parent_path + [name]]
+ end
+ nesting_key = [nesting[0].__id__, nesting[1]].join('::')
+ nesting_value = is_class ? Types::CLASS : Types::MODULE
+ end
+ else
+ # parent_module == :unknown
+ # TODO: dummy module
+ end
+ module_scope = Scope.new(
+ scope,
+ { **table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil },
+ trace_ivar: false,
+ trace_lvar: false,
+ self_type: Types::UnionType[*module_types],
+ nesting: nesting
+ )
+ module_scope[nesting_key] = nesting_value if nesting_value
+ result = evaluate(node.body, module_scope)
+ scope.update module_scope
+ result
+ end
+
+ def evaluate_for_node(node, scope)
+ node.statements
+ collection = evaluate node.collection, scope
+ inner_scope = Scope.new scope, { Scope::BREAK_RESULT => nil }
+ ary_type = method_call collection, :to_ary, [], nil, nil, nil, name_match: false
+ element_types = ary_type.types.filter_map do |ary|
+ ary.params[:Elem] if ary.is_a?(Types::InstanceType) && ary.klass == Array
+ end
+ element_type = Types::UnionType[*element_types]
+ inner_scope.conditional do |s|
+ evaluate_write node.index, element_type, s, nil
+ evaluate node.statements, s if node.statements
+ end
+ inner_scope.merge_jumps
+ scope.update inner_scope
+ breaks = inner_scope[Scope::BREAK_RESULT]
+ breaks ? Types::UnionType[breaks, collection] : collection
+ end
+
+ def evaluate_case_node(node, scope)
+ target = evaluate(node.predicate, scope) if node.predicate
+ # TODO
+ branches = node.conditions.map do |condition|
+ ->(s) { evaluate_case_match target, condition, s }
+ end
+ if node.consequent
+ branches << ->(s) { evaluate node.consequent, s }
+ elsif node.conditions.any? { _1.is_a? Prism::WhenNode }
+ branches << ->(_s) { Types::NIL }
+ end
+ Types::UnionType[*scope.run_branches(*branches)]
+ end
+
+ def evaluate_match_required_node(node, scope)
+ value_type = evaluate node.value, scope
+ evaluate_match_pattern value_type, node.pattern, scope
+ Types::NIL # void value
+ end
+
+ def evaluate_match_predicate_node(node, scope)
+ value_type = evaluate node.value, scope
+ scope.conditional { evaluate_match_pattern value_type, node.pattern, _1 }
+ Types::BOOLEAN
+ end
+
+ def evaluate_range_node(node, scope)
+ beg_type = evaluate node.left, scope if node.left
+ end_type = evaluate node.right, scope if node.right
+ elem = (Types::UnionType[*[beg_type, end_type].compact]).nonnillable
+ Types::InstanceType.new Range, Elem: elem
+ end
+
+ def evaluate_defined_node(node, scope)
+ scope.conditional { evaluate node.value, _1 }
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_flip_flop_node(node, scope)
+ scope.conditional { evaluate node.left, _1 } if node.left
+ scope.conditional { evaluate node.right, _1 } if node.right
+ Types::BOOLEAN
+ end
+
+ def evaluate_multi_target_node(node, scope)
+ # Raw MultiTargetNode, incomplete code like `a,b`, `*a`.
+ evaluate_multi_write_receiver node, scope, nil
+ Types::NIL
+ end
+
+ def evaluate_splat_node(node, scope)
+ # Raw SplatNode, incomplete code like `*a.`
+ evaluate_multi_write_receiver node.expression, scope, nil if node.expression
+ Types::NIL
+ end
+
+ def evaluate_implicit_node(node, scope)
+ evaluate node.value, scope
+ end
+
+ def evaluate_match_write_node(node, scope)
+ # /(?<a>)(?<b>)/ =~ string
+ evaluate node.call, scope
+ node.locals.each { scope[_1.to_s] = Types::UnionType[Types::STRING, Types::NIL] }
+ Types::BOOLEAN
+ end
+
+ def evaluate_match_last_line_node(_node, _scope)
+ Types::BOOLEAN
+ end
+
+ def evaluate_interpolated_match_last_line_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::BOOLEAN
+ end
+
+ def evaluate_pre_execution_node(node, scope)
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ end
+
+ def evaluate_post_execution_node(node, scope)
+ node.statements && @dig_targets.dig?(node.statements) ? evaluate(node.statements, scope) : Types::NIL
+ end
+
+ def evaluate_alias_method_node(_node, _scope) = Types::NIL
+ def evaluate_alias_global_variable_node(_node, _scope) = Types::NIL
+ def evaluate_undef_node(_node, _scope) = Types::NIL
+ def evaluate_missing_node(_node, _scope) = Types::NIL
+
+ def evaluate_call_node_arguments(call_node, scope)
+ # call_node.arguments is Prism::ArgumentsNode
+ arguments = call_node.arguments&.arguments&.dup || []
+ block_arg = call_node.block.expression if call_node.block.is_a?(Prism::BlockArgumentNode)
+ kwargs = arguments.pop.elements if arguments.last.is_a?(Prism::KeywordHashNode)
+ args_types = arguments.map do |arg|
+ case arg
+ when Prism::ForwardingArgumentsNode
+ # `f(a, ...)` treat like splat
+ nil
+ when Prism::SplatNode
+ evaluate arg.expression, scope if arg.expression
+ nil # TODO: splat
+ else
+ evaluate arg, scope
+ end
+ end
+ if kwargs
+ kwargs_types = kwargs.map do |arg|
+ case arg
+ when Prism::AssocNode
+ if arg.key.is_a?(Prism::SymbolNode)
+ [arg.key.value, evaluate(arg.value, scope)]
+ else
+ evaluate arg.key, scope
+ evaluate arg.value, scope
+ nil
+ end
+ when Prism::AssocSplatNode
+ evaluate arg.value, scope if arg.value
+ nil
+ end
+ end.compact.to_h
+ end
+ if block_arg.is_a? Prism::SymbolNode
+ block_sym_node = block_arg
+ elsif block_arg
+ evaluate block_arg, scope
+ end
+ [args_types, kwargs_types, block_sym_node, !!block_arg]
+ end
+
+ def const_path_write(receiver, name, value, scope)
+ if receiver # receiver::A = value
+ singleton_type = receiver.types.find { _1.is_a? Types::SingletonType }
+ scope.set_const singleton_type.module_or_class, name, value if singleton_type
+ else # ::A = value
+ scope.set_const Object, name, value
+ end
+ end
+
+ def assign_required_parameter(node, value, scope)
+ case node
+ when Prism::RequiredParameterNode
+ scope[node.name.to_s] = value || Types::OBJECT
+ when Prism::MultiTargetNode
+ parameters = [*node.lefts, *node.rest, *node.rights]
+ values = value ? sized_splat(value, :to_ary, parameters.size) : []
+ parameters.zip values do |n, v|
+ assign_required_parameter n, v, scope
+ end
+ when Prism::SplatNode
+ splat_value = value ? Types.array_of(value) : Types::ARRAY
+ assign_required_parameter node.expression, splat_value, scope if node.expression
+ end
+ end
+
+ def evaluate_constant_node_info(node, scope)
+ case node
+ when Prism::ConstantPathNode
+ name = node.child.name.to_s
+ if node.parent
+ receiver = evaluate node.parent, scope
+ if receiver.is_a? Types::SingletonType
+ parent_module = receiver.module_or_class
+ end
+ else
+ parent_module = Object
+ end
+ if parent_module
+ type = scope.get_const(parent_module, [name]) || Types::NIL
+ else
+ parent_module = :unknown
+ type = Types::NIL
+ end
+ when Prism::ConstantReadNode
+ name = node.name.to_s
+ type = scope[name]
+ end
+ @dig_targets.resolve type, scope if @dig_targets.target? node
+ [type, receiver, parent_module, name]
+ end
+
+
+ def assign_parameters(node, scope, args, kwargs)
+ args = args.dup
+ kwargs = kwargs.dup
+ size = node.requireds.size + node.optionals.size + (node.rest ? 1 : 0) + node.posts.size
+ args = sized_splat(args.first, :to_ary, size) if size >= 2 && args.size == 1
+ reqs = args.shift node.requireds.size
+ if node.rest
+ # node.rest is Prism::RestParameterNode
+ posts = []
+ opts = args.shift node.optionals.size
+ rest = args
+ else
+ posts = args.pop node.posts.size
+ opts = args
+ rest = []
+ end
+ node.requireds.zip reqs do |n, v|
+ assign_required_parameter n, v, scope
+ end
+ node.optionals.zip opts do |n, v|
+ # n is Prism::OptionalParameterNode
+ values = [v]
+ values << evaluate(n.value, scope) if n.value
+ scope[n.name.to_s] = Types::UnionType[*values.compact]
+ end
+ node.posts.zip posts do |n, v|
+ assign_required_parameter n, v, scope
+ end
+ if node.rest&.name
+ # node.rest is Prism::RestParameterNode
+ scope[node.rest.name.to_s] = Types.array_of(*rest)
+ end
+ node.keywords.each do |n|
+ name = n.name.to_s.delete(':')
+ values = [kwargs.delete(name)]
+ # n is Prism::OptionalKeywordParameterNode (has n.value) or Prism::RequiredKeywordParameterNode (does not have n.value)
+ values << evaluate(n.value, scope) if n.respond_to?(:value)
+ scope[name] = Types::UnionType[*values.compact]
+ end
+ # node.keyword_rest is Prism::KeywordRestParameterNode or Prism::ForwardingParameterNode or Prism::NoKeywordsParameterNode
+ if node.keyword_rest.is_a?(Prism::KeywordRestParameterNode) && node.keyword_rest.name
+ scope[node.keyword_rest.name.to_s] = Types::InstanceType.new(Hash, K: Types::SYMBOL, V: Types::UnionType[*kwargs.values])
+ end
+ if node.block&.name
+ # node.block is Prism::BlockParameterNode
+ scope[node.block.name.to_s] = Types::PROC
+ end
+ end
+
+ def assign_numbered_parameters(numbered_parameters, scope, args, _kwargs)
+ return if numbered_parameters.empty?
+ max_num = numbered_parameters.map { _1[1].to_i }.max
+ if max_num == 1
+ scope['_1'] = args.first || Types::NIL
+ else
+ args = sized_splat(args.first, :to_ary, max_num) if args.size == 1
+ numbered_parameters.each do |name|
+ index = name[1].to_i - 1
+ scope[name] = args[index] || Types::NIL
+ end
+ end
+ end
+
+ def evaluate_case_match(target, node, scope)
+ case node
+ when Prism::WhenNode
+ node.conditions.each { evaluate _1, scope }
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ when Prism::InNode
+ pattern = node.pattern
+ if pattern.is_a?(Prism::IfNode) || pattern.is_a?(Prism::UnlessNode)
+ cond_node = pattern.predicate
+ pattern = pattern.statements.body.first
+ end
+ evaluate_match_pattern(target, pattern, scope)
+ evaluate cond_node, scope if cond_node # TODO: conditional branch
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ end
+ end
+
+ def evaluate_match_pattern(value, pattern, scope)
+ # TODO: scope.terminate_with Scope::PATTERNMATCH_BREAK, Types::NIL
+ case pattern
+ when Prism::FindPatternNode
+ # TODO
+ evaluate_match_pattern Types::OBJECT, pattern.left, scope
+ pattern.requireds.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ evaluate_match_pattern Types::OBJECT, pattern.right, scope
+ when Prism::ArrayPatternNode
+ # TODO
+ pattern.requireds.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ evaluate_match_pattern Types::OBJECT, pattern.rest, scope if pattern.rest
+ pattern.posts.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ Types::ARRAY
+ when Prism::HashPatternNode
+ # TODO
+ pattern.elements.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ if pattern.respond_to?(:rest) && pattern.rest
+ evaluate_match_pattern Types::OBJECT, pattern.rest, scope
+ end
+ Types::HASH
+ when Prism::AssocNode
+ evaluate_match_pattern value, pattern.value, scope if pattern.value
+ Types::OBJECT
+ when Prism::AssocSplatNode
+ # TODO
+ evaluate_match_pattern Types::HASH, pattern.value, scope
+ Types::OBJECT
+ when Prism::PinnedVariableNode
+ evaluate pattern.variable, scope
+ when Prism::PinnedExpressionNode
+ evaluate pattern.expression, scope
+ when Prism::LocalVariableTargetNode
+ scope[pattern.name.to_s] = value
+ when Prism::AlternationPatternNode
+ Types::UnionType[evaluate_match_pattern(value, pattern.left, scope), evaluate_match_pattern(value, pattern.right, scope)]
+ when Prism::CapturePatternNode
+ capture_type = class_or_value_to_instance evaluate_match_pattern(value, pattern.value, scope)
+ value = capture_type unless capture_type.types.empty? || capture_type.types == [Types::OBJECT]
+ evaluate_match_pattern value, pattern.target, scope
+ when Prism::SplatNode
+ value = Types.array_of value
+ evaluate_match_pattern value, pattern.expression, scope if pattern.expression
+ value
+ else
+ # literal node
+ type = evaluate(pattern, scope)
+ class_or_value_to_instance(type)
+ end
+ end
+
+ def class_or_value_to_instance(type)
+ instance_types = type.types.map do |t|
+ t.is_a?(Types::SingletonType) ? Types::InstanceType.new(t.module_or_class) : t
+ end
+ Types::UnionType[*instance_types]
+ end
+
+ def evaluate_write(node, value, scope, evaluated_receivers)
+ case node
+ when Prism::MultiTargetNode
+ evaluate_multi_write node, value, scope, evaluated_receivers
+ when Prism::CallNode
+ evaluated_receivers&.[](node.receiver) || evaluate(node.receiver, scope) if node.receiver
+ when Prism::SplatNode
+ evaluate_write node.expression, Types.array_of(value), scope, evaluated_receivers if node.expression
+ when Prism::LocalVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::ConstantTargetNode
+ scope[node.name.to_s] = value
+ when Prism::ConstantPathTargetNode
+ receiver = evaluated_receivers&.[](node.parent) || evaluate(node.parent, scope) if node.parent
+ const_path_write receiver, node.child.name.to_s, value, scope
+ value
+ end
+ end
+
+ def evaluate_multi_write(node, values, scope, evaluated_receivers)
+ pre_targets = node.lefts
+ splat_target = node.rest
+ post_targets = node.rights
+ size = pre_targets.size + (splat_target ? 1 : 0) + post_targets.size
+ values = values.is_a?(Array) ? values.dup : sized_splat(values, :to_ary, size)
+ pre_pairs = pre_targets.zip(values.shift(pre_targets.size))
+ post_pairs = post_targets.zip(values.pop(post_targets.size))
+ splat_pairs = splat_target ? [[splat_target, Types::UnionType[*values]]] : []
+ (pre_pairs + splat_pairs + post_pairs).each do |target, value|
+ evaluate_write target, value || Types::NIL, scope, evaluated_receivers
+ end
+ end
+
+ def evaluate_multi_write_receiver(node, scope, evaluated_receivers)
+ case node
+ when Prism::MultiWriteNode, Prism::MultiTargetNode
+ targets = [*node.lefts, *node.rest, *node.rights]
+ targets.each { evaluate_multi_write_receiver _1, scope, evaluated_receivers }
+ when Prism::CallNode
+ if node.receiver
+ receiver = evaluate(node.receiver, scope)
+ evaluated_receivers[node.receiver] = receiver if evaluated_receivers
+ end
+ if node.arguments
+ node.arguments.arguments&.each do |arg|
+ if arg.is_a? Prism::SplatNode
+ evaluate arg.expression, scope
+ else
+ evaluate arg, scope
+ end
+ end
+ end
+ when Prism::SplatNode
+ evaluate_multi_write_receiver node.expression, scope, evaluated_receivers if node.expression
+ end
+ end
+
+ def evaluate_list_splat_items(list, scope)
+ items = list.flat_map do |node|
+ if node.is_a? Prism::SplatNode
+ next unless node.expression # def f(*); [*]
+
+ splat = evaluate node.expression, scope
+ array_elem, non_array = partition_to_array splat.nonnillable, :to_a
+ [*array_elem, *non_array]
+ else
+ evaluate node, scope
+ end
+ end.compact.uniq
+ Types::UnionType[*items]
+ end
+
+ def sized_splat(value, method, size)
+ array_elem, non_array = partition_to_array value, method
+ values = [Types::UnionType[*array_elem, *non_array]]
+ values += [array_elem] * (size - 1) if array_elem && size >= 1
+ values
+ end
+
+ def partition_to_array(value, method)
+ arrays, non_arrays = value.types.partition { _1.is_a?(Types::InstanceType) && _1.klass == Array }
+ non_arrays.select! do |type|
+ to_array_result = method_call type, method, [], nil, nil, nil, name_match: false
+ if to_array_result.is_a?(Types::InstanceType) && to_array_result.klass == Array
+ arrays << to_array_result
+ false
+ else
+ true
+ end
+ end
+ array_elem = arrays.empty? ? nil : Types::UnionType[*arrays.map { _1.params[:Elem] || Types::OBJECT }]
+ non_array = non_arrays.empty? ? nil : Types::UnionType[*non_arrays]
+ [array_elem, non_array]
+ end
+
+ def method_call(receiver, method_name, args, kwargs, block, scope, name_match: true)
+ methods = Types.rbs_methods receiver, method_name.to_sym, args, kwargs, !!block
+ block_called = false
+ type_breaks = methods.map do |method, given_params, method_params|
+ receiver_vars = receiver.is_a?(Types::InstanceType) ? receiver.params : {}
+ free_vars = method.type.free_variables - receiver_vars.keys.to_set
+ vars = receiver_vars.merge Types.match_free_variables(free_vars, method_params, given_params)
+ if block && method.block
+ params_type = method.block.type.required_positionals.map do |func_param|
+ Types.from_rbs_type func_param.type, receiver, vars
+ end
+ self_type = Types.from_rbs_type method.block.self_type, receiver, vars if method.block.self_type
+ block_response, breaks = block.call params_type, self_type
+ block_called = true
+ vars.merge! Types.match_free_variables(free_vars - vars.keys.to_set, [method.block.type.return_type], [block_response])
+ end
+ if Types.method_return_bottom?(method)
+ [nil, breaks]
+ else
+ [Types.from_rbs_type(method.type.return_type, receiver, vars || {}), breaks]
+ end
+ end
+ block&.call [], nil unless block_called
+ terminates = !type_breaks.empty? && type_breaks.map(&:first).all?(&:nil?)
+ types = type_breaks.map(&:first).compact
+ breaks = type_breaks.map(&:last).compact
+ types << OBJECT_METHODS[method_name.to_sym] if name_match && OBJECT_METHODS.has_key?(method_name.to_sym)
+
+ if method_name.to_sym == :new
+ receiver.types.each do |type|
+ if type.is_a?(Types::SingletonType) && type.module_or_class.is_a?(Class)
+ types << Types::InstanceType.new(type.module_or_class)
+ end
+ end
+ end
+ scope&.terminate if terminates && breaks.empty?
+ Types::UnionType[*types, *breaks]
+ end
+
+ def self.calculate_target_type_scope(binding, parents, target)
+ dig_targets = DigTarget.new(parents, target) do |type, scope|
+ return type, scope
+ end
+ program = parents.first
+ scope = Scope.from_binding(binding, program.locals)
+ new(dig_targets).evaluate program, scope
+ [Types::NIL, scope]
+ end
+ end
+ end
+end
diff --git a/lib/irb/type_completion/types.rb b/lib/irb/type_completion/types.rb
new file mode 100644
index 0000000000..f0f2342ffe
--- /dev/null
+++ b/lib/irb/type_completion/types.rb
@@ -0,0 +1,426 @@
+# frozen_string_literal: true
+
+require_relative 'methods'
+
+module IRB
+ module TypeCompletion
+ module Types
+ OBJECT_TO_TYPE_SAMPLE_SIZE = 50
+
+ singleton_class.attr_reader :rbs_builder, :rbs_load_error
+
+ def self.preload_in_thread
+ return if @preload_started
+
+ @preload_started = true
+ Thread.new do
+ load_rbs_builder
+ end
+ end
+
+ def self.load_rbs_builder
+ require 'rbs'
+ require 'rbs/cli'
+ loader = RBS::CLI::LibraryOptions.new.loader
+ loader.add path: Pathname('sig')
+ @rbs_builder = RBS::DefinitionBuilder.new env: RBS::Environment.from_loader(loader).resolve_type_names
+ rescue LoadError, StandardError => e
+ @rbs_load_error = e
+ nil
+ end
+
+ def self.class_name_of(klass)
+ klass = klass.superclass if klass.singleton_class?
+ Methods::MODULE_NAME_METHOD.bind_call klass
+ end
+
+ def self.rbs_search_method(klass, method_name, singleton)
+ klass.ancestors.each do |ancestor|
+ name = class_name_of ancestor
+ next unless name && rbs_builder
+ type_name = RBS::TypeName(name).absolute!
+ definition = (singleton ? rbs_builder.build_singleton(type_name) : rbs_builder.build_instance(type_name)) rescue nil
+ method = definition.methods[method_name] if definition
+ return method if method
+ end
+ nil
+ end
+
+ def self.method_return_type(type, method_name)
+ receivers = type.types.map do |t|
+ case t
+ in SingletonType
+ [t, t.module_or_class, true]
+ in InstanceType
+ [t, t.klass, false]
+ end
+ end
+ types = receivers.flat_map do |receiver_type, klass, singleton|
+ method = rbs_search_method klass, method_name, singleton
+ next [] unless method
+ method.method_types.map do |method|
+ from_rbs_type(method.type.return_type, receiver_type, {})
+ end
+ end
+ UnionType[*types]
+ end
+
+ def self.rbs_methods(type, method_name, args_types, kwargs_type, has_block)
+ return [] unless rbs_builder
+
+ receivers = type.types.map do |t|
+ case t
+ in SingletonType
+ [t, t.module_or_class, true]
+ in InstanceType
+ [t, t.klass, false]
+ end
+ end
+ has_splat = args_types.include?(nil)
+ methods_with_score = receivers.flat_map do |receiver_type, klass, singleton|
+ method = rbs_search_method klass, method_name, singleton
+ next [] unless method
+ method.method_types.map do |method_type|
+ score = 0
+ score += 2 if !!method_type.block == has_block
+ reqs = method_type.type.required_positionals
+ opts = method_type.type.optional_positionals
+ rest = method_type.type.rest_positionals
+ trailings = method_type.type.trailing_positionals
+ keyreqs = method_type.type.required_keywords
+ keyopts = method_type.type.optional_keywords
+ keyrest = method_type.type.rest_keywords
+ args = args_types
+ if kwargs_type&.any? && keyreqs.empty? && keyopts.empty? && keyrest.nil?
+ kw_value_type = UnionType[*kwargs_type.values]
+ args += [InstanceType.new(Hash, K: SYMBOL, V: kw_value_type)]
+ end
+ if has_splat
+ score += 1 if args.count(&:itself) <= reqs.size + opts.size + trailings.size
+ elsif reqs.size + trailings.size <= args.size && (rest || args.size <= reqs.size + opts.size + trailings.size)
+ score += 2
+ centers = args[reqs.size...-trailings.size]
+ given = args.first(reqs.size) + centers.take(opts.size) + args.last(trailings.size)
+ expected = (reqs + opts.take(centers.size) + trailings).map(&:type)
+ if rest
+ given << UnionType[*centers.drop(opts.size)]
+ expected << rest.type
+ end
+ if given.any?
+ score += given.zip(expected).count do |t, e|
+ e = from_rbs_type e, receiver_type
+ intersect?(t, e) || (intersect?(STRING, e) && t.methods.include?(:to_str)) || (intersect?(INTEGER, e) && t.methods.include?(:to_int)) || (intersect?(ARRAY, e) && t.methods.include?(:to_ary))
+ end.fdiv(given.size)
+ end
+ end
+ [[method_type, given || [], expected || []], score]
+ end
+ end
+ max_score = methods_with_score.map(&:last).max
+ methods_with_score.select { _2 == max_score }.map(&:first)
+ end
+
+ def self.intersect?(a, b)
+ atypes = a.types.group_by(&:class)
+ btypes = b.types.group_by(&:class)
+ if atypes[SingletonType] && btypes[SingletonType]
+ aa, bb = [atypes, btypes].map {|types| types[SingletonType].map(&:module_or_class) }
+ return true if (aa & bb).any?
+ end
+
+ aa, bb = [atypes, btypes].map {|types| (types[InstanceType] || []).map(&:klass) }
+ (aa.flat_map(&:ancestors) & bb).any?
+ end
+
+ def self.type_from_object(object)
+ case object
+ when Array
+ InstanceType.new Array, { Elem: union_type_from_objects(object) }
+ when Hash
+ InstanceType.new Hash, { K: union_type_from_objects(object.keys), V: union_type_from_objects(object.values) }
+ when Module
+ SingletonType.new object
+ else
+ klass = Methods::OBJECT_SINGLETON_CLASS_METHOD.bind_call(object) rescue Methods::OBJECT_CLASS_METHOD.bind_call(object)
+ InstanceType.new klass
+ end
+ end
+
+ def self.union_type_from_objects(objects)
+ values = objects.size <= OBJECT_TO_TYPE_SAMPLE_SIZE ? objects : objects.sample(OBJECT_TO_TYPE_SAMPLE_SIZE)
+ klasses = values.map { Methods::OBJECT_CLASS_METHOD.bind_call(_1) }
+ UnionType[*klasses.uniq.map { InstanceType.new _1 }]
+ end
+
+ class SingletonType
+ attr_reader :module_or_class
+ def initialize(module_or_class)
+ @module_or_class = module_or_class
+ end
+ def transform() = yield(self)
+ def methods() = @module_or_class.methods
+ def all_methods() = methods | Kernel.methods
+ def constants() = @module_or_class.constants
+ def types() = [self]
+ def nillable?() = false
+ def nonnillable() = self
+ def inspect
+ "#{module_or_class}.itself"
+ end
+ end
+
+ class InstanceType
+ attr_reader :klass, :params
+ def initialize(klass, params = {})
+ @klass = klass
+ @params = params
+ end
+ def transform() = yield(self)
+ def methods() = rbs_methods.select { _2.public? }.keys | @klass.instance_methods
+ def all_methods() = rbs_methods.keys | @klass.instance_methods | @klass.private_instance_methods
+ def constants() = []
+ def types() = [self]
+ def nillable?() = (@klass == NilClass)
+ def nonnillable() = self
+ def rbs_methods
+ name = Types.class_name_of(@klass)
+ return {} unless name && Types.rbs_builder
+
+ type_name = RBS::TypeName(name).absolute!
+ Types.rbs_builder.build_instance(type_name).methods rescue {}
+ end
+ def inspect
+ if params.empty?
+ inspect_without_params
+ else
+ params_string = "[#{params.map { "#{_1}: #{_2.inspect}" }.join(', ')}]"
+ "#{inspect_without_params}#{params_string}"
+ end
+ end
+ def inspect_without_params
+ if klass == NilClass
+ 'nil'
+ elsif klass == TrueClass
+ 'true'
+ elsif klass == FalseClass
+ 'false'
+ else
+ klass.singleton_class? ? klass.superclass.to_s : klass.to_s
+ end
+ end
+ end
+
+ NIL = InstanceType.new NilClass
+ OBJECT = InstanceType.new Object
+ TRUE = InstanceType.new TrueClass
+ FALSE = InstanceType.new FalseClass
+ SYMBOL = InstanceType.new Symbol
+ STRING = InstanceType.new String
+ INTEGER = InstanceType.new Integer
+ RANGE = InstanceType.new Range
+ REGEXP = InstanceType.new Regexp
+ FLOAT = InstanceType.new Float
+ RATIONAL = InstanceType.new Rational
+ COMPLEX = InstanceType.new Complex
+ ARRAY = InstanceType.new Array
+ HASH = InstanceType.new Hash
+ CLASS = InstanceType.new Class
+ MODULE = InstanceType.new Module
+ PROC = InstanceType.new Proc
+
+ class UnionType
+ attr_reader :types
+
+ def initialize(*types)
+ @types = []
+ singletons = []
+ instances = {}
+ collect = -> type do
+ case type
+ in UnionType
+ type.types.each(&collect)
+ in InstanceType
+ params = (instances[type.klass] ||= {})
+ type.params.each do |k, v|
+ (params[k] ||= []) << v
+ end
+ in SingletonType
+ singletons << type
+ end
+ end
+ types.each(&collect)
+ @types = singletons.uniq + instances.map do |klass, params|
+ InstanceType.new(klass, params.transform_values { |v| UnionType[*v] })
+ end
+ end
+
+ def transform(&block)
+ UnionType[*types.map(&block)]
+ end
+
+ def nillable?
+ types.any?(&:nillable?)
+ end
+
+ def nonnillable
+ UnionType[*types.reject { _1.is_a?(InstanceType) && _1.klass == NilClass }]
+ end
+
+ def self.[](*types)
+ type = new(*types)
+ if type.types.empty?
+ OBJECT
+ elsif type.types.size == 1
+ type.types.first
+ else
+ type
+ end
+ end
+
+ def methods() = @types.flat_map(&:methods).uniq
+ def all_methods() = @types.flat_map(&:all_methods).uniq
+ def constants() = @types.flat_map(&:constants).uniq
+ def inspect() = @types.map(&:inspect).join(' | ')
+ end
+
+ BOOLEAN = UnionType[TRUE, FALSE]
+
+ def self.array_of(*types)
+ type = types.size >= 2 ? UnionType[*types] : types.first || OBJECT
+ InstanceType.new Array, Elem: type
+ end
+
+ def self.from_rbs_type(return_type, self_type, extra_vars = {})
+ case return_type
+ when RBS::Types::Bases::Self
+ self_type
+ when RBS::Types::Bases::Bottom, RBS::Types::Bases::Nil
+ NIL
+ when RBS::Types::Bases::Any, RBS::Types::Bases::Void
+ OBJECT
+ when RBS::Types::Bases::Class
+ self_type.transform do |type|
+ case type
+ in SingletonType
+ InstanceType.new(self_type.module_or_class.is_a?(Class) ? Class : Module)
+ in InstanceType
+ SingletonType.new type.klass
+ end
+ end
+ UnionType[*types]
+ when RBS::Types::Bases::Bool
+ BOOLEAN
+ when RBS::Types::Bases::Instance
+ self_type.transform do |type|
+ if type.is_a?(SingletonType) && type.module_or_class.is_a?(Class)
+ InstanceType.new type.module_or_class
+ else
+ OBJECT
+ end
+ end
+ when RBS::Types::Union
+ UnionType[*return_type.types.map { from_rbs_type _1, self_type, extra_vars }]
+ when RBS::Types::Proc
+ PROC
+ when RBS::Types::Tuple
+ elem = UnionType[*return_type.types.map { from_rbs_type _1, self_type, extra_vars }]
+ InstanceType.new Array, Elem: elem
+ when RBS::Types::Record
+ InstanceType.new Hash, K: SYMBOL, V: OBJECT
+ when RBS::Types::Literal
+ InstanceType.new return_type.literal.class
+ when RBS::Types::Variable
+ if extra_vars.key? return_type.name
+ extra_vars[return_type.name]
+ elsif self_type.is_a? InstanceType
+ self_type.params[return_type.name] || OBJECT
+ elsif self_type.is_a? UnionType
+ types = self_type.types.filter_map do |t|
+ t.params[return_type.name] if t.is_a? InstanceType
+ end
+ UnionType[*types]
+ else
+ OBJECT
+ end
+ when RBS::Types::Optional
+ UnionType[from_rbs_type(return_type.type, self_type, extra_vars), NIL]
+ when RBS::Types::Alias
+ case return_type.name.name
+ when :int
+ INTEGER
+ when :boolish
+ BOOLEAN
+ when :string
+ STRING
+ else
+ # TODO: ???
+ OBJECT
+ end
+ when RBS::Types::Interface
+ # unimplemented
+ OBJECT
+ when RBS::Types::ClassInstance
+ klass = return_type.name.to_namespace.path.reduce(Object) { _1.const_get _2 }
+ if return_type.args
+ args = return_type.args.map { from_rbs_type _1, self_type, extra_vars }
+ names = rbs_builder.build_singleton(return_type.name).type_params
+ params = names.map.with_index { [_1, args[_2] || OBJECT] }.to_h
+ end
+ InstanceType.new klass, params || {}
+ end
+ end
+
+ def self.method_return_bottom?(method)
+ method.type.return_type.is_a? RBS::Types::Bases::Bottom
+ end
+
+ def self.match_free_variables(vars, types, values)
+ accumulator = {}
+ types.zip values do |t, v|
+ _match_free_variable(vars, t, v, accumulator) if v
+ end
+ accumulator.transform_values { UnionType[*_1] }
+ end
+
+ def self._match_free_variable(vars, rbs_type, value, accumulator)
+ case [rbs_type, value]
+ in [RBS::Types::Variable,]
+ (accumulator[rbs_type.name] ||= []) << value if vars.include? rbs_type.name
+ in [RBS::Types::ClassInstance, InstanceType]
+ names = rbs_builder.build_singleton(rbs_type.name).type_params
+ names.zip(rbs_type.args).each do |name, arg|
+ v = value.params[name]
+ _match_free_variable vars, arg, v, accumulator if v
+ end
+ in [RBS::Types::Tuple, InstanceType] if value.klass == Array
+ v = value.params[:Elem]
+ rbs_type.types.each do |t|
+ _match_free_variable vars, t, v, accumulator
+ end
+ in [RBS::Types::Record, InstanceType] if value.klass == Hash
+ # TODO
+ in [RBS::Types::Interface,]
+ definition = rbs_builder.build_interface rbs_type.name
+ convert = {}
+ definition.type_params.zip(rbs_type.args).each do |from, arg|
+ convert[from] = arg.name if arg.is_a? RBS::Types::Variable
+ end
+ return if convert.empty?
+ ac = {}
+ definition.methods.each do |method_name, method|
+ return_type = method_return_type value, method_name
+ method.defs.each do |method_def|
+ interface_return_type = method_def.type.type.return_type
+ _match_free_variable convert, interface_return_type, return_type, ac
+ end
+ end
+ convert.each do |from, to|
+ values = ac[from]
+ (accumulator[to] ||= []).concat values if values
+ end
+ else
+ end
+ end
+ end
+ end
+end
diff --git a/test/irb/test_cmd.rb b/test/irb/test_cmd.rb
index 67dcfd0a63..219710c921 100644
--- a/test/irb/test_cmd.rb
+++ b/test/irb/test_cmd.rb
@@ -90,6 +90,7 @@ module TestIRB
Ruby\sversion:\s.+\n
IRB\sversion:\sirb\s.+\n
InputMethod:\sAbstract\sInputMethod\n
+ Completion: .+\n
\.irbrc\spath:\s.+\n
RUBY_PLATFORM:\s.+\n
East\sAsian\sAmbiguous\sWidth:\s\d\n
@@ -113,6 +114,7 @@ module TestIRB
Ruby\sversion:\s.+\n
IRB\sversion:\sirb\s.+\n
InputMethod:\sAbstract\sInputMethod\n
+ Completion: .+\n
\.irbrc\spath:\s.+\n
RUBY_PLATFORM:\s.+\n
East\sAsian\sAmbiguous\sWidth:\s\d\n
@@ -139,6 +141,7 @@ module TestIRB
Ruby\sversion:\s.+\n
IRB\sversion:\sirb\s.+\n
InputMethod:\sAbstract\sInputMethod\n
+ Completion: .+\n
RUBY_PLATFORM:\s.+\n
East\sAsian\sAmbiguous\sWidth:\s\d\n
#{@is_win ? 'Code\spage:\s\d+\n' : ''}
@@ -168,6 +171,7 @@ module TestIRB
Ruby\sversion:\s.+\n
IRB\sversion:\sirb\s.+\n
InputMethod:\sAbstract\sInputMethod\n
+ Completion: .+\n
RUBY_PLATFORM:\s.+\n
East\sAsian\sAmbiguous\sWidth:\s\d\n
#{@is_win ? 'Code\spage:\s\d+\n' : ''}
@@ -196,6 +200,7 @@ module TestIRB
Ruby\sversion: .+\n
IRB\sversion:\sirb .+\n
InputMethod:\sAbstract\sInputMethod\n
+ Completion: .+\n
\.irbrc\spath: .+\n
RUBY_PLATFORM: .+\n
LANG\senv:\sja_JP\.UTF-8\n
diff --git a/test/irb/test_context.rb b/test/irb/test_context.rb
index af47bec9de..ce57df6cdb 100644
--- a/test/irb/test_context.rb
+++ b/test/irb/test_context.rb
@@ -652,6 +652,24 @@ module TestIRB
], out)
end
+ def test_build_completor
+ verbose, $VERBOSE = $VERBOSE, nil
+ original_completor = IRB.conf[:COMPLETOR]
+ IRB.conf[:COMPLETOR] = :regexp
+ assert_equal 'IRB::RegexpCompletor', @context.send(:build_completor).class.name
+ IRB.conf[:COMPLETOR] = :type
+ if RUBY_VERSION >= '3.0.0' && RUBY_ENGINE != 'truffleruby'
+ assert_equal 'IRB::TypeCompletion::Completor', @context.send(:build_completor).class.name
+ else
+ assert_equal 'IRB::RegexpCompletor', @context.send(:build_completor).class.name
+ end
+ IRB.conf[:COMPLETOR] = :unknown
+ assert_equal 'IRB::RegexpCompletor', @context.send(:build_completor).class.name
+ ensure
+ $VERBOSE = verbose
+ IRB.conf[:COMPLETOR] = original_completor
+ end
+
private
def without_colorize
diff --git a/test/irb/test_input_method.rb b/test/irb/test_input_method.rb
index 2d8cfadcf5..e6a1b06e82 100644
--- a/test/irb/test_input_method.rb
+++ b/test/irb/test_input_method.rb
@@ -24,7 +24,7 @@ module TestIRB
def test_initialization
Reline.completion_proc = nil
Reline.dig_perfect_match_proc = nil
- IRB::RelineInputMethod.new
+ IRB::RelineInputMethod.new(IRB::RegexpCompletor.new)
assert_nil Reline.completion_append_character
assert_equal '', Reline.completer_quote_characters
@@ -40,7 +40,7 @@ module TestIRB
IRB.conf[:USE_AUTOCOMPLETE] = false
- IRB::RelineInputMethod.new
+ IRB::RelineInputMethod.new(IRB::RegexpCompletor.new)
refute Reline.autocompletion
assert_equal empty_proc, Reline.dialog_proc(:show_doc).dialog_proc
@@ -55,7 +55,7 @@ module TestIRB
IRB.conf[:USE_AUTOCOMPLETE] = true
- IRB::RelineInputMethod.new
+ IRB::RelineInputMethod.new(IRB::RegexpCompletor.new)
assert Reline.autocompletion
assert_not_equal empty_proc, Reline.dialog_proc(:show_doc).dialog_proc
@@ -71,7 +71,7 @@ module TestIRB
IRB.conf[:USE_AUTOCOMPLETE] = true
without_rdoc do
- IRB::RelineInputMethod.new
+ IRB::RelineInputMethod.new(IRB::RegexpCompletor.new)
end
assert Reline.autocompletion
@@ -89,7 +89,7 @@ module TestIRB
end
def display_document(target, bind)
- input_method = IRB::RelineInputMethod.new
+ input_method = IRB::RelineInputMethod.new(IRB::RegexpCompletor.new)
input_method.instance_variable_set(:@completion_params, [target, '', '', bind])
input_method.display_document(target, driver: @driver)
end
diff --git a/test/irb/type_completion/test_scope.rb b/test/irb/type_completion/test_scope.rb
new file mode 100644
index 0000000000..d7f9540b06
--- /dev/null
+++ b/test/irb/type_completion/test_scope.rb
@@ -0,0 +1,112 @@
+# frozen_string_literal: true
+
+return unless RUBY_VERSION >= '3.0.0'
+return if RUBY_ENGINE == 'truffleruby' # needs endless method definition
+
+require 'irb/type_completion/scope'
+require_relative '../helper'
+
+module TestIRB
+ class TypeCompletionScopeTest < TestCase
+ A, B, C, D, E, F, G, H, I, J, K = ('A'..'K').map do |name|
+ klass = Class.new
+ klass.define_singleton_method(:inspect) { name }
+ IRB::TypeCompletion::Types::InstanceType.new(klass)
+ end
+
+ def assert_type(expected_types, type)
+ assert_equal [*expected_types].map(&:klass).to_set, type.types.map(&:klass).to_set
+ end
+
+ def table(*local_variable_names)
+ local_variable_names.to_h { [_1, IRB::TypeCompletion::Types::NIL] }
+ end
+
+ def base_scope
+ IRB::TypeCompletion::RootScope.new(binding, Object.new, [])
+ end
+
+ def test_lvar
+ scope = IRB::TypeCompletion::Scope.new base_scope, table('a')
+ scope['a'] = A
+ assert_equal A, scope['a']
+ end
+
+ def test_conditional
+ scope = IRB::TypeCompletion::Scope.new base_scope, table('a')
+ scope.conditional do |sub_scope|
+ sub_scope['a'] = A
+ end
+ assert_type [A, IRB::TypeCompletion::Types::NIL], scope['a']
+ end
+
+ def test_branch
+ scope = IRB::TypeCompletion::Scope.new base_scope, table('a', 'b', 'c', 'd')
+ scope['c'] = A
+ scope['d'] = B
+ scope.run_branches(
+ -> { _1['a'] = _1['c'] = _1['d'] = C },
+ -> { _1['a'] = _1['b'] = _1['d'] = D },
+ -> { _1['a'] = _1['b'] = _1['d'] = E },
+ -> { _1['a'] = _1['b'] = _1['c'] = F; _1.terminate }
+ )
+ assert_type [C, D, E], scope['a']
+ assert_type [IRB::TypeCompletion::Types::NIL, D, E], scope['b']
+ assert_type [A, C], scope['c']
+ assert_type [C, D, E], scope['d']
+ end
+
+ def test_scope_local_variables
+ scope1 = IRB::TypeCompletion::Scope.new base_scope, table('a', 'b')
+ scope2 = IRB::TypeCompletion::Scope.new scope1, table('b', 'c'), trace_lvar: false
+ scope3 = IRB::TypeCompletion::Scope.new scope2, table('c', 'd')
+ scope4 = IRB::TypeCompletion::Scope.new scope2, table('d', 'e')
+ assert_empty base_scope.local_variables
+ assert_equal %w[a b], scope1.local_variables.sort
+ assert_equal %w[b c], scope2.local_variables.sort
+ assert_equal %w[b c d], scope3.local_variables.sort
+ assert_equal %w[b c d e], scope4.local_variables.sort
+ end
+
+ def test_nested_scope
+ scope = IRB::TypeCompletion::Scope.new base_scope, table('a', 'b', 'c')
+ scope['a'] = A
+ scope['b'] = A
+ scope['c'] = A
+ sub_scope = IRB::TypeCompletion::Scope.new scope, { 'c' => B }
+ assert_type A, sub_scope['a']
+
+ assert_type A, sub_scope['b']
+ assert_type B, sub_scope['c']
+ sub_scope['a'] = C
+ sub_scope.conditional { _1['b'] = C }
+ sub_scope['c'] = C
+ assert_type C, sub_scope['a']
+ assert_type [A, C], sub_scope['b']
+ assert_type C, sub_scope['c']
+ scope.update sub_scope
+ assert_type C, scope['a']
+ assert_type [A, C], scope['b']
+ assert_type A, scope['c']
+ end
+
+ def test_break
+ scope = IRB::TypeCompletion::Scope.new base_scope, table('a')
+ scope['a'] = A
+ breakable_scope = IRB::TypeCompletion::Scope.new scope, { IRB::TypeCompletion::Scope::BREAK_RESULT => nil }
+ breakable_scope.conditional do |sub|
+ sub['a'] = B
+ assert_type [B], sub['a']
+ sub.terminate_with IRB::TypeCompletion::Scope::BREAK_RESULT, C
+ sub['a'] = C
+ assert_type [C], sub['a']
+ end
+ assert_type [A], breakable_scope['a']
+ breakable_scope[IRB::TypeCompletion::Scope::BREAK_RESULT] = D
+ breakable_scope.merge_jumps
+ assert_type [C, D], breakable_scope[IRB::TypeCompletion::Scope::BREAK_RESULT]
+ scope.update breakable_scope
+ assert_type [A, B], scope['a']
+ end
+ end
+end
diff --git a/test/irb/type_completion/test_type_analyze.rb b/test/irb/type_completion/test_type_analyze.rb
new file mode 100644
index 0000000000..c417a8ad12
--- /dev/null
+++ b/test/irb/type_completion/test_type_analyze.rb
@@ -0,0 +1,697 @@
+# frozen_string_literal: true
+
+# Run test only when Ruby >= 3.0 and %w[prism rbs] are available
+return unless RUBY_VERSION >= '3.0.0'
+return if RUBY_ENGINE == 'truffleruby' # needs endless method definition
+begin
+ require 'prism'
+ require 'rbs'
+rescue LoadError
+ return
+end
+
+
+require 'irb/completion'
+require 'irb/type_completion/completor'
+require_relative '../helper'
+
+module TestIRB
+ class TypeCompletionAnalyzeTest < TestCase
+ def setup
+ IRB::TypeCompletion::Types.load_rbs_builder unless IRB::TypeCompletion::Types.rbs_builder
+ end
+
+ def empty_binding
+ binding
+ end
+
+ def analyze(code, binding: nil)
+ completor = IRB::TypeCompletion::Completor.new
+ def completor.handle_error(e)
+ raise e
+ end
+ completor.analyze(code, binding || empty_binding)
+ end
+
+ def assert_analyze_type(code, type, token = nil, binding: empty_binding)
+ result_type, result_token = analyze(code, binding: binding)
+ assert_equal type, result_type
+ assert_equal token, result_token if token
+ end
+
+ def assert_call(code, include: nil, exclude: nil, binding: nil)
+ raise ArgumentError if include.nil? && exclude.nil?
+
+ result = analyze(code.strip, binding: binding)
+ type = result[1] if result[0] == :call
+ klasses = type.types.flat_map do
+ _1.klass.singleton_class? ? [_1.klass.superclass, _1.klass] : _1.klass
+ end
+ assert ([*include] - klasses).empty?, "Expected #{klasses} to include #{include}" if include
+ assert (klasses & [*exclude]).empty?, "Expected #{klasses} not to include #{exclude}" if exclude
+ end
+
+ def test_lvar_ivar_gvar_cvar
+ assert_analyze_type('puts(x', :lvar_or_method, 'x')
+ assert_analyze_type('puts($', :gvar, '$')
+ assert_analyze_type('puts($x', :gvar, '$x')
+ assert_analyze_type('puts(@', :ivar, '@')
+ assert_analyze_type('puts(@x', :ivar, '@x')
+ assert_analyze_type('puts(@@', :cvar, '@@')
+ assert_analyze_type('puts(@@x', :cvar, '@@x')
+ end
+
+ def test_rescue
+ assert_call '(1 rescue 1.0).', include: [Integer, Float]
+ assert_call 'a=""; (a=1) rescue (a=1.0); a.', include: [Integer, Float], exclude: String
+ assert_call 'begin; 1; rescue; 1.0; end.', include: [Integer, Float]
+ assert_call 'begin; 1; rescue A; 1.0; rescue B; 1i; end.', include: [Integer, Float, Complex]
+ assert_call 'begin; 1i; rescue; 1.0; else; 1; end.', include: [Integer, Float], exclude: Complex
+ assert_call 'begin; 1; rescue; 1.0; ensure; 1i; end.', include: [Integer, Float], exclude: Complex
+ assert_call 'begin; 1i; rescue; 1.0; else; 1; ensure; 1i; end.', include: [Integer, Float], exclude: Complex
+ assert_call 'a=""; begin; a=1; rescue; a=1.0; end; a.', include: [Integer, Float], exclude: [String]
+ assert_call 'a=""; begin; a=1; rescue; a=1.0; else; a=1r; end; a.', include: [Float, Rational], exclude: [String, Integer]
+ assert_call 'a=""; begin; a=1; rescue; a=1.0; else; a=1r; ensure; a = 1i; end; a.', include: Complex, exclude: [Float, Rational, String, Integer]
+ end
+
+ def test_rescue_assign
+ assert_equal [:lvar_or_method, 'a'], analyze('begin; rescue => a')[0, 2]
+ assert_equal [:gvar, '$a'], analyze('begin; rescue => $a')[0, 2]
+ assert_equal [:ivar, '@a'], analyze('begin; rescue => @a')[0, 2]
+ assert_equal [:cvar, '@@a'], analyze('begin; rescue => @@a')[0, 2]
+ assert_equal [:const, 'A'], analyze('begin; rescue => A').values_at(0, 2)
+ assert_equal [:call, 'b'], analyze('begin; rescue => a.b').values_at(0, 2)
+ end
+
+ def test_ref
+ bind = eval <<~RUBY
+ class (Module.new)::A
+ @ivar = :a
+ @@cvar = 'a'
+ binding
+ end
+ RUBY
+ assert_call('STDIN.', include: STDIN.singleton_class)
+ assert_call('$stdin.', include: $stdin.singleton_class)
+ assert_call('@ivar.', include: Symbol, binding: bind)
+ assert_call('@@cvar.', include: String, binding: bind)
+ lbind = eval('lvar = 1; binding')
+ assert_call('lvar.', include: Integer, binding: lbind)
+ end
+
+ def test_self_ivar_ref
+ obj = Object.new
+ obj.instance_variable_set(:@hoge, 1)
+ assert_call('obj.instance_eval { @hoge.', include: Integer, binding: obj.instance_eval { binding })
+ if Class.method_defined? :attached_object
+ bind = binding
+ assert_call('obj.instance_eval { @hoge.', include: Integer, binding: bind)
+ assert_call('@hoge = 1.0; obj.instance_eval { @hoge.', include: Integer, exclude: Float, binding: bind)
+ assert_call('@hoge = 1.0; obj.instance_eval { @hoge = "" }; @hoge.', include: Float, exclude: [Integer, String], binding: bind)
+ assert_call('@fuga = 1.0; obj.instance_eval { @fuga.', exclude: Float, binding: bind)
+ assert_call('@fuga = 1.0; obj.instance_eval { @fuga = "" }; @fuga.', include: Float, exclude: [Integer, String], binding: bind)
+ end
+ end
+
+ class CVarModule
+ @@test_cvar = 1
+ end
+ def test_module_cvar_ref
+ bind = binding
+ assert_call('@@foo=1; class A; @@foo.', exclude: Integer, binding: bind)
+ assert_call('@@foo=1; class A; @@foo=1.0; @@foo.', include: Float, exclude: Integer, binding: bind)
+ assert_call('@@foo=1; class A; @@foo=1.0; end; @@foo.', include: Integer, exclude: Float, binding: bind)
+ assert_call('module CVarModule; @@test_cvar.', include: Integer, binding: bind)
+ assert_call('class Array; @@foo = 1; end; class Array; @@foo.', include: Integer, binding: bind)
+ assert_call('class Array; class B; @@foo = 1; end; class B; @@foo.', include: Integer, binding: bind)
+ assert_call('class Array; class B; @@foo = 1; end; @@foo.', exclude: Integer, binding: bind)
+ end
+
+ def test_lvar_singleton_method
+ a = 1
+ b = +''
+ c = Object.new
+ d = [a, b, c]
+ binding = Kernel.binding
+ assert_call('a.', include: Integer, exclude: String, binding: binding)
+ assert_call('b.', include: b.singleton_class, exclude: [Integer, Object], binding: binding)
+ assert_call('c.', include: c.singleton_class, exclude: [Integer, String], binding: binding)
+ assert_call('d.', include: d.class, exclude: [Integer, String, Object], binding: binding)
+ assert_call('d.sample.', include: [Integer, String, Object], exclude: [b.singleton_class, c.singleton_class], binding: binding)
+ end
+
+ def test_local_variable_assign
+ assert_call('(a = 1).', include: Integer)
+ assert_call('a = 1; a = ""; a.', include: String, exclude: Integer)
+ assert_call('1 => a; a.', include: Integer)
+ end
+
+ def test_block_symbol
+ assert_call('[1].map(&:', include: Integer)
+ assert_call('1.to_s.tap(&:', include: String)
+ end
+
+ def test_union_splat
+ assert_call('a, = [[:a], 1, nil].sample; a.', include: [Symbol, Integer, NilClass], exclude: Object)
+ assert_call('[[:a], 1, nil].each do _2; _1.', include: [Symbol, Integer, NilClass], exclude: Object)
+ assert_call('a = [[:a], 1, nil, ("a".."b")].sample; [*a].sample.', include: [Symbol, Integer, NilClass, String], exclude: Object)
+ end
+
+ def test_range
+ assert_call('(1..2).first.', include: Integer)
+ assert_call('("a".."b").first.', include: String)
+ assert_call('(..1.to_f).first.', include: Float)
+ assert_call('(1.to_s..).first.', include: String)
+ assert_call('(1..2.0).first.', include: [Float, Integer])
+ end
+
+ def test_conditional_assign
+ assert_call('a = 1; a = "" if cond; a.', include: [String, Integer], exclude: NilClass)
+ assert_call('a = 1 if cond; a.', include: [Integer, NilClass])
+ assert_call(<<~RUBY, include: [String, Symbol], exclude: [Integer, NilClass])
+ a = 1
+ cond ? a = '' : a = :a
+ a.
+ RUBY
+ end
+
+ def test_block
+ assert_call('nil.then{1}.', include: Integer, exclude: NilClass)
+ assert_call('nil.then(&:to_s).', include: String, exclude: NilClass)
+ end
+
+ def test_block_break
+ assert_call('1.tap{}.', include: [Integer], exclude: NilClass)
+ assert_call('1.tap{break :a}.', include: [Symbol, Integer], exclude: NilClass)
+ assert_call('1.tap{break :a, :b}[0].', include: Symbol)
+ assert_call('1.tap{break :a; break "a"}.', include: [Symbol, Integer], exclude: [NilClass, String])
+ assert_call('1.tap{break :a if b}.', include: [Symbol, Integer], exclude: NilClass)
+ assert_call('1.tap{break :a; break "a" if b}.', include: [Symbol, Integer], exclude: [NilClass, String])
+ assert_call('1.tap{if cond; break :a; else; break "a"; end}.', include: [Symbol, Integer, String], exclude: NilClass)
+ end
+
+ def test_instance_eval
+ assert_call('1.instance_eval{:a.then{self.', include: Integer, exclude: Symbol)
+ assert_call('1.then{:a.instance_eval{self.', include: Symbol, exclude: Integer)
+ end
+
+ def test_block_next
+ assert_call('nil.then{1}.', include: Integer, exclude: [NilClass, Object])
+ assert_call('nil.then{next 1}.', include: Integer, exclude: [NilClass, Object])
+ assert_call('nil.then{next :a, :b}[0].', include: Symbol)
+ assert_call('nil.then{next 1; 1.0}.', include: Integer, exclude: [Float, NilClass, Object])
+ assert_call('nil.then{next 1; next 1.0}.', include: Integer, exclude: [Float, NilClass, Object])
+ assert_call('nil.then{1 if cond}.', include: [Integer, NilClass], exclude: Object)
+ assert_call('nil.then{if cond; 1; else; 1.0; end}.', include: [Integer, Float], exclude: [NilClass, Object])
+ assert_call('nil.then{next 1 if cond; 1.0}.', include: [Integer, Float], exclude: [NilClass, Object])
+ assert_call('nil.then{if cond; next 1; else; next 1.0; end; "a"}.', include: [Integer, Float], exclude: [String, NilClass, Object])
+ assert_call('nil.then{if cond; next 1; else; next 1.0; end; next "a"}.', include: [Integer, Float], exclude: [String, NilClass, Object])
+ end
+
+ def test_vars_with_branch_termination
+ assert_call('a=1; tap{break; a=//}; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; tap{a=1.0; break; a=//}; a.', include: [Integer, Float], exclude: Regexp)
+ assert_call('a=1; tap{next; a=//}; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; tap{a=1.0; next; a=//}; a.', include: [Integer, Float], exclude: Regexp)
+ assert_call('a=1; while cond; break; a=//; end; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; while cond; a=1.0; break; a=//; end; a.', include: [Integer, Float], exclude: Regexp)
+ assert_call('a=1; ->{ break; a=// }; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; ->{ a=1.0; break; a=// }; a.', include: [Integer, Float], exclude: Regexp)
+
+ assert_call('a=1; tap{ break; a=// if cond }; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; tap{ next; a=// if cond }; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; while cond; break; a=// if cond; end; a.', include: Integer, exclude: Regexp)
+ assert_call('a=1; ->{ break; a=// if cond }; a.', include: Integer, exclude: Regexp)
+
+ assert_call('a=1; tap{if cond; a=:a; break; a=""; end; a.', include: Integer, exclude: [Symbol, String])
+ assert_call('a=1; tap{if cond; a=:a; break; a=""; end; a=//}; a.', include: [Integer, Symbol, Regexp], exclude: String)
+ assert_call('a=1; tap{if cond; a=:a; break; a=""; else; break; end; a=//}; a.', include: [Integer, Symbol], exclude: [String, Regexp])
+ assert_call('a=1; tap{if cond; a=:a; next; a=""; end; a.', include: Integer, exclude: [Symbol, String])
+ assert_call('a=1; tap{if cond; a=:a; next; a=""; end; a=//}; a.', include: [Integer, Symbol, Regexp], exclude: String)
+ assert_call('a=1; tap{if cond; a=:a; next; a=""; else; next; end; a=//}; a.', include: [Integer, Symbol], exclude: [String, Regexp])
+ assert_call('def f(a=1); if cond; a=:a; return; a=""; end; a.', include: Integer, exclude: [Symbol, String])
+ assert_call('a=1; while cond; if cond; a=:a; break; a=""; end; a.', include: Integer, exclude: [Symbol, String])
+ assert_call('a=1; while cond; if cond; a=:a; break; a=""; end; a=//; end; a.', include: [Integer, Symbol, Regexp], exclude: String)
+ assert_call('a=1; while cond; if cond; a=:a; break; a=""; else; break; end; a=//; end; a.', include: [Integer, Symbol], exclude: [String, Regexp])
+ assert_call('a=1; ->{ if cond; a=:a; break; a=""; end; a.', include: Integer, exclude: [Symbol, String])
+ assert_call('a=1; ->{ if cond; a=:a; break; a=""; end; a=// }; a.', include: [Integer, Symbol, Regexp], exclude: String)
+ assert_call('a=1; ->{ if cond; a=:a; break; a=""; else; break; end; a=// }; a.', include: [Integer, Symbol], exclude: [String, Regexp])
+
+ # continue evaluation on terminated branch
+ assert_call('a=1; tap{ a=1.0; break; a=// if cond; a.', include: [Regexp, Float], exclude: Integer)
+ assert_call('a=1; tap{ a=1.0; next; a=// if cond; a.', include: [Regexp, Float], exclude: Integer)
+ assert_call('a=1; ->{ a=1.0; break; a=// if cond; a.', include: [Regexp, Float], exclude: Integer)
+ assert_call('a=1; while cond; a=1.0; break; a=// if cond; a.', include: [Regexp, Float], exclude: Integer)
+ end
+
+ def test_to_str_to_int
+ sobj = Struct.new(:to_str).new('a')
+ iobj = Struct.new(:to_int).new(1)
+ binding = Kernel.binding
+ assert_equal String, ([] * sobj).class
+ assert_equal Array, ([] * iobj).class
+ assert_call('([]*sobj).', include: String, exclude: Array, binding: binding)
+ assert_call('([]*iobj).', include: Array, exclude: String, binding: binding)
+ end
+
+ def test_method_select
+ assert_call('([]*4).', include: Array, exclude: String)
+ assert_call('([]*"").', include: String, exclude: Array)
+ assert_call('([]*unknown).', include: [String, Array])
+ assert_call('p(1).', include: Integer)
+ assert_call('p(1, 2).', include: Array, exclude: Integer)
+ assert_call('2.times.', include: Enumerator, exclude: Integer)
+ assert_call('2.times{}.', include: Integer, exclude: Enumerator)
+ end
+
+ def test_interface_match_var
+ assert_call('([1]+[:a]+["a"]).sample.', include: [Integer, String, Symbol])
+ end
+
+ def test_lvar_scope
+ code = <<~RUBY
+ tap { a = :never }
+ a = 1 if x?
+ tap {|a| a = :never }
+ tap { a = 'maybe' }
+ a = {} if x?
+ a.
+ RUBY
+ assert_call(code, include: [Hash, Integer, String], exclude: [Symbol])
+ end
+
+ def test_lvar_scope_complex
+ assert_call('if cond; a = 1; else; tap { a = :a }; end; a.', include: [NilClass, Integer, Symbol], exclude: [Object])
+ assert_call('def f; if cond; a = 1; return; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object])
+ assert_call('def f; if cond; return; a = 1; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object])
+ assert_call('def f; if cond; return; if cond; return; a = 1; end; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object])
+ assert_call('def f; if cond; return; if cond; return; a = 1; end; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object])
+ end
+
+ def test_gvar_no_scope
+ code = <<~RUBY
+ tap { $a = :maybe }
+ $a = 'maybe' if x?
+ $a.
+ RUBY
+ assert_call(code, include: [Symbol, String])
+ end
+
+ def test_ivar_no_scope
+ code = <<~RUBY
+ tap { @a = :maybe }
+ @a = 'maybe' if x?
+ @a.
+ RUBY
+ assert_call(code, include: [Symbol, String])
+ end
+
+ def test_massign
+ assert_call('(a,=1).', include: Integer)
+ assert_call('(a,=[*1])[0].', include: Integer)
+ assert_call('(a,=[1,2])[0].', include: Integer)
+ assert_call('a,=[1,2]; a.', include: Integer, exclude: Array)
+ assert_call('a,b=[1,2]; a.', include: Integer, exclude: Array)
+ assert_call('a,b=[1,2]; b.', include: Integer, exclude: Array)
+ assert_call('a,*,b=[1,2]; a.', include: Integer, exclude: Array)
+ assert_call('a,*,b=[1,2]; b.', include: Integer, exclude: Array)
+ assert_call('a,*b=[1,2]; a.', include: Integer, exclude: Array)
+ assert_call('a,*b=[1,2]; b.', include: Array, exclude: Integer)
+ assert_call('a,*b=[1,2]; b.sample.', include: Integer)
+ assert_call('a,*,(*)=[1,2]; a.', include: Integer)
+ assert_call('*a=[1,2]; a.', include: Array, exclude: Integer)
+ assert_call('*a=[1,2]; a.sample.', include: Integer)
+ assert_call('a,*b,c=[1,2,3]; b.', include: Array, exclude: Integer)
+ assert_call('a,*b,c=[1,2,3]; b.sample.', include: Integer)
+ assert_call('a,b=(cond)?[1,2]:[:a,:b]; a.', include: [Integer, Symbol])
+ assert_call('a,b=(cond)?[1,2]:[:a,:b]; b.', include: [Integer, Symbol])
+ assert_call('a,b=(cond)?[1,2]:"s"; a.', include: [Integer, String])
+ assert_call('a,b=(cond)?[1,2]:"s"; b.', include: Integer, exclude: String)
+ assert_call('a,*b=(cond)?[1,2]:"s"; a.', include: [Integer, String])
+ assert_call('a,*b=(cond)?[1,2]:"s"; b.', include: Array, exclude: [Integer, String])
+ assert_call('a,*b=(cond)?[1,2]:"s"; b.sample.', include: Integer, exclude: String)
+ assert_call('*a=(cond)?[1,2]:"s"; a.', include: Array, exclude: [Integer, String])
+ assert_call('*a=(cond)?[1,2]:"s"; a.sample.', include: [Integer, String])
+ assert_call('a,(b,),c=[1,[:a],4]; b.', include: Symbol)
+ assert_call('a,(b,(c,))=1; a.', include: Integer)
+ assert_call('a,(b,(*c))=1; c.', include: Array)
+ assert_call('(a=1).b, c = 1; a.', include: Integer)
+ assert_call('a, ((b=1).c, d) = 1; b.', include: Integer)
+ assert_call('a, b[c=1] = 1; c.', include: Integer)
+ assert_call('a, b[*(c=1)] = 1; c.', include: Integer)
+ # incomplete massign
+ assert_analyze_type('a,b', :lvar_or_method, 'b')
+ assert_call('(a=1).b, a.', include: Integer)
+ assert_call('a=1; *a.', include: Integer)
+ end
+
+ def test_field_assign
+ assert_call('(a.!=1).', exclude: Integer)
+ assert_call('(a.b=1).', include: Integer, exclude: NilClass)
+ assert_call('(a&.b=1).', include: Integer)
+ assert_call('(nil&.b=1).', include: NilClass)
+ assert_call('(a[]=1).', include: Integer)
+ assert_call('(a[b]=1).', include: Integer)
+ assert_call('(a.[]=1).', exclude: Integer)
+ end
+
+ def test_def
+ assert_call('def f; end.', include: Symbol)
+ assert_call('s=""; def s.f; self.', include: String)
+ assert_call('def (a="").f; end; a.', include: String)
+ assert_call('def f(a=1); a.', include: Integer)
+ assert_call('def f(**nil); 1.', include: Integer)
+ assert_call('def f((*),*); 1.', include: Integer)
+ assert_call('def f(a,*b); b.', include: Array)
+ assert_call('def f(a,x:1); x.', include: Integer)
+ assert_call('def f(a,x:,**); 1.', include: Integer)
+ assert_call('def f(a,x:,**y); y.', include: Hash)
+ assert_call('def f((*a)); a.', include: Array)
+ assert_call('def f(a,b=1,*c,d,x:0,y:,**z,&e); e.arity.', include: Integer)
+ assert_call('def f(...); 1.', include: Integer)
+ assert_call('def f(a,...); 1.', include: Integer)
+ assert_call('def f(...); g(...); 1.', include: Integer)
+ assert_call('def f(*,**,&); g(*,**,&); 1.', include: Integer)
+ assert_call('def f(*,**,&); {**}.', include: Hash)
+ assert_call('def f(*,**,&); [*,**].', include: Array)
+ assert_call('class Array; def f; self.', include: Array)
+ end
+
+ def test_defined
+ assert_call('defined?(a.b+c).', include: [String, NilClass])
+ assert_call('defined?(a = 1); tap { a = 1.0 }; a.', include: [Integer, Float, NilClass])
+ end
+
+ def test_ternary_operator
+ assert_call('condition ? 1.chr.', include: [String])
+ assert_call('condition ? value : 1.chr.', include: [String])
+ assert_call('condition ? cond ? cond ? value : cond ? value : 1.chr.', include: [String])
+ end
+
+ def test_block_parameter
+ assert_call('method { |arg = 1.chr.', include: [String])
+ assert_call('method do |arg = 1.chr.', include: [String])
+ assert_call('method { |arg1 = 1.|(2|3), arg2 = 1.chr.', include: [String])
+ assert_call('method do |arg1 = 1.|(2|3), arg2 = 1.chr.', include: [String])
+ end
+
+ def test_self
+ integer_binding = 1.instance_eval { Kernel.binding }
+ assert_call('self.', include: [Integer], binding: integer_binding)
+ string = +''
+ string_binding = string.instance_eval { Kernel.binding }
+ assert_call('self.', include: [string.singleton_class], binding: string_binding)
+ object = Object.new
+ object.instance_eval { @int = 1; @string = string }
+ object_binding = object.instance_eval { Kernel.binding }
+ assert_call('self.', include: [object.singleton_class], binding: object_binding)
+ assert_call('@int.', include: [Integer], binding: object_binding)
+ assert_call('@string.', include: [String], binding: object_binding)
+ end
+
+ def test_optional_chain
+ assert_call('[1,nil].sample.', include: [Integer, NilClass])
+ assert_call('[1,nil].sample&.', include: [Integer], exclude: [NilClass])
+ assert_call('[1,nil].sample.chr.', include: [String], exclude: [NilClass])
+ assert_call('[1,nil].sample&.chr.', include: [String, NilClass])
+ assert_call('[1,nil].sample.chr&.ord.', include: [Integer], exclude: [NilClass])
+ assert_call('a = 1; b.c(a = :a); a.', include: [Symbol], exclude: [Integer])
+ assert_call('a = 1; b&.c(a = :a); a.', include: [Integer, Symbol])
+ end
+
+ def test_class_module
+ assert_call('class (1.', include: Integer)
+ assert_call('class (a=1)::B; end; a.', include: Integer)
+ assert_call('class Array; 1; end.', include: Integer)
+ assert_call('class ::Array; 1; end.', include: Integer)
+ assert_call('class Array::A; 1; end.', include: Integer)
+ assert_call('class Array; self.new.', include: Array)
+ assert_call('class ::Array; self.new.', include: Array)
+ assert_call('class Array::A; self.', include: Class)
+ assert_call('class (a=1)::A; end; a.', include: Integer)
+ assert_call('module M; 1; end.', include: Integer)
+ assert_call('module ::M; 1; end.', include: Integer)
+ assert_call('module Array::M; 1; end.', include: Integer)
+ assert_call('module M; self.', include: Module)
+ assert_call('module Array::M; self.', include: Module)
+ assert_call('module ::M; self.', include: Module)
+ assert_call('module (a=1)::M; end; a.', include: Integer)
+ assert_call('class << Array; 1; end.', include: Integer)
+ assert_call('class << a; 1; end.', include: Integer)
+ assert_call('a = ""; class << a; self.superclass.', include: Class)
+ end
+
+ def test_constant_path
+ assert_call('class A; X=1; class B; X=""; X.', include: String, exclude: Integer)
+ assert_call('class A; X=1; class B; X=""; end; X.', include: Integer, exclude: String)
+ assert_call('class A; class B; X=1; end; end; class A; class B; X.', include: Integer)
+ assert_call('module IRB; VERSION.', include: String)
+ assert_call('module IRB; IRB::VERSION.', include: String)
+ assert_call('module IRB; VERSION=1; VERSION.', include: Integer)
+ assert_call('module IRB; VERSION=1; IRB::VERSION.', include: Integer)
+ assert_call('module IRB; module A; VERSION.', include: String)
+ assert_call('module IRB; module A; VERSION=1; VERSION.', include: Integer)
+ assert_call('module IRB; module A; VERSION=1; IRB::VERSION.', include: String)
+ assert_call('module IRB; module A; VERSION=1; end; VERSION.', include: String)
+ assert_call('module IRB; IRB=1; IRB.', include: Integer)
+ assert_call('module IRB; IRB=1; ::IRB::VERSION.', include: String)
+ module_binding = eval 'module ::IRB; binding; end'
+ assert_call('VERSION.', include: NilClass)
+ assert_call('VERSION.', include: String, binding: module_binding)
+ assert_call('IRB::VERSION.', include: String, binding: module_binding)
+ assert_call('A = 1; module M; A += 0.5; A.', include: Float)
+ assert_call('::A = 1; module M; A += 0.5; A.', include: Float)
+ assert_call('::A = 1; module M; A += 0.5; ::A.', include: Integer)
+ assert_call('IRB::A = 1; IRB::A += 0.5; IRB::A.', include: Float)
+ end
+
+ def test_literal
+ assert_call('1.', include: Integer)
+ assert_call('1.0.', include: Float)
+ assert_call('1r.', include: Rational)
+ assert_call('1i.', include: Complex)
+ assert_call('true.', include: TrueClass)
+ assert_call('false.', include: FalseClass)
+ assert_call('nil.', include: NilClass)
+ assert_call('().', include: NilClass)
+ assert_call('//.', include: Regexp)
+ assert_call('/#{a=1}/.', include: Regexp)
+ assert_call('/#{a=1}/; a.', include: Integer)
+ assert_call(':a.', include: Symbol)
+ assert_call(':"#{a=1}".', include: Symbol)
+ assert_call(':"#{a=1}"; a.', include: Integer)
+ assert_call('"".', include: String)
+ assert_call('"#$a".', include: String)
+ assert_call('("a" "b").', include: String)
+ assert_call('"#{a=1}".', include: String)
+ assert_call('"#{a=1}"; a.', include: Integer)
+ assert_call('``.', include: String)
+ assert_call('`#{a=1}`.', include: String)
+ assert_call('`#{a=1}`; a.', include: Integer)
+ end
+
+ def test_redo_retry_yield_super
+ assert_call('a=nil; tap do a=1; redo; a=1i; end; a.', include: Integer, exclude: Complex)
+ assert_call('a=nil; tap do a=1; retry; a=1i; end; a.', include: Integer, exclude: Complex)
+ assert_call('a = 0; a = yield; a.', include: Object, exclude: Integer)
+ assert_call('yield 1,(a=1); a.', include: Integer)
+ assert_call('a = 0; a = super; a.', include: Object, exclude: Integer)
+ assert_call('a = 0; a = super(1); a.', include: Object, exclude: Integer)
+ assert_call('super 1,(a=1); a.', include: Integer)
+ end
+
+ def test_rarely_used_syntax
+ # FlipFlop
+ assert_call('if (a=1).even?..(a=1.0).even; a.', include: [Integer, Float])
+ # MatchLastLine
+ assert_call('if /regexp/; 1.', include: Integer)
+ assert_call('if /reg#{a=1}exp/; a.', include: Integer)
+ # BlockLocalVariable
+ assert_call('tap do |i;a| a=1; a.', include: Integer)
+ # BEGIN{} END{}
+ assert_call('BEGIN{1.', include: Integer)
+ assert_call('END{1.', include: Integer)
+ # MatchWrite
+ assert_call('a=1; /(?<a>)/=~b; a.', include: [String, NilClass], exclude: Integer)
+ # OperatorWrite with block `a[&b]+=c`
+ assert_call('a=[1]; (a[0,&:to_a]+=1.0).', include: Float)
+ assert_call('a=[1]; (a[0,&b]+=1.0).', include: Float)
+ end
+
+ def test_hash
+ assert_call('{}.', include: Hash)
+ assert_call('{**a}.', include: Hash)
+ assert_call('{ rand: }.values.sample.', include: Float)
+ assert_call('rand=""; { rand: }.values.sample.', include: String, exclude: Float)
+ assert_call('{ 1 => 1.0 }.keys.sample.', include: Integer, exclude: Float)
+ assert_call('{ 1 => 1.0 }.values.sample.', include: Float, exclude: Integer)
+ assert_call('a={1=>1.0}; {"a"=>1i,**a}.keys.sample.', include: [Integer, String])
+ assert_call('a={1=>1.0}; {"a"=>1i,**a}.values.sample.', include: [Float, Complex])
+ end
+
+ def test_array
+ assert_call('[1,2,3].sample.', include: Integer)
+ assert_call('a = 1.0; [1,2,a].sample.', include: [Integer, Float])
+ assert_call('a = [1.0]; [1,2,*a].sample.', include: [Integer, Float])
+ end
+
+ def test_numbered_parameter
+ assert_call('loop{_1.', include: NilClass)
+ assert_call('1.tap{_1.', include: Integer)
+ assert_call('1.tap{_3.', include: NilClass, exclude: Integer)
+ assert_call('[:a,1].tap{_1.', include: Array, exclude: [Integer, Symbol])
+ assert_call('[:a,1].tap{_2.', include: [Symbol, Integer], exclude: Array)
+ assert_call('[:a,1].tap{_2; _1.', include: [Symbol, Integer], exclude: Array)
+ assert_call('[:a].each_with_index{_1.', include: Symbol, exclude: [Integer, Array])
+ assert_call('[:a].each_with_index{_2; _1.', include: Symbol, exclude: [Integer, Array])
+ assert_call('[:a].each_with_index{_2.', include: Integer, exclude: Symbol)
+ end
+
+ def test_if_unless
+ assert_call('if cond; 1; end.', include: Integer)
+ assert_call('unless true; 1; end.', include: Integer)
+ assert_call('a=1; (a=1.0) if cond; a.', include: [Integer, Float])
+ assert_call('a=1; (a=1.0) unless cond; a.', include: [Integer, Float])
+ assert_call('a=1; 123 if (a=1.0).foo; a.', include: Float, exclude: Integer)
+ assert_call('if cond; a=1; end; a.', include: [Integer, NilClass])
+ assert_call('a=1; if cond; a=1.0; elsif cond; a=1r; else; a=1i; end; a.', include: [Float, Rational, Complex], exclude: Integer)
+ assert_call('a=1; if cond; a=1.0; else; a.', include: Integer, exclude: Float)
+ assert_call('a=1; if (a=1.0).foo; a.', include: Float, exclude: Integer)
+ assert_call('a=1; if (a=1.0).foo; end; a.', include: Float, exclude: Integer)
+ assert_call('a=1; if (a=1.0).foo; else; a.', include: Float, exclude: Integer)
+ assert_call('a=1; if (a=1.0).foo; elsif a.', include: Float, exclude: Integer)
+ assert_call('a=1; if (a=1.0).foo; elsif (a=1i); else; a.', include: Complex, exclude: [Integer, Float])
+ end
+
+ def test_while_until
+ assert_call('while cond; 123; end.', include: NilClass)
+ assert_call('until cond; 123; end.', include: NilClass)
+ assert_call('a=1; a=1.0 while cond; a.', include: [Integer, Float])
+ assert_call('a=1; a=1.0 until cond; a.', include: [Integer, Float])
+ assert_call('a=1; 1 while (a=1.0).foo; a.', include: Float, exclude: Integer)
+ assert_call('while cond; break 1; end.', include: Integer)
+ assert_call('while cond; a=1; end; a.', include: Integer)
+ assert_call('a=1; while cond; a=1.0; end; a.', include: [Integer, Float])
+ assert_call('a=1; while (a=1.0).foo; end; a.', include: Float, exclude: Integer)
+ end
+
+ def test_for
+ assert_call('for i in [1,2,3]; i.', include: Integer)
+ assert_call('for i,j in [1,2,3]; i.', include: Integer)
+ assert_call('for *,(*) in [1,2,3]; 1.', include: Integer)
+ assert_call('for *i in [1,2,3]; i.sample.', include: Integer)
+ assert_call('for (a=1).b in [1,2,3]; a.', include: Integer)
+ assert_call('for Array::B in [1,2,3]; Array::B.', include: Integer)
+ assert_call('for A in [1,2,3]; A.', include: Integer)
+ assert_call('for $a in [1,2,3]; $a.', include: Integer)
+ assert_call('for @a in [1,2,3]; @a.', include: Integer)
+ assert_call('for i in [1,2,3]; end.', include: Array)
+ assert_call('for i in [1,2,3]; break 1.0; end.', include: [Array, Float])
+ assert_call('i = 1.0; for i in [1,2,3]; end; i.', include: [Integer, Float])
+ assert_call('a = 1.0; for i in [1,2,3]; a = 1i; end; a.', include: [Float, Complex])
+ end
+
+ def test_special_var
+ assert_call('__FILE__.', include: String)
+ assert_call('__LINE__.', include: Integer)
+ assert_call('__ENCODING__.', include: Encoding)
+ assert_call('$1.', include: String)
+ assert_call('$&.', include: String)
+ end
+
+ def test_and_or
+ assert_call('(1&&1.0).', include: Float, exclude: Integer)
+ assert_call('(nil&&1.0).', include: NilClass)
+ assert_call('(nil||1).', include: Integer)
+ assert_call('(1||1.0).', include: Float)
+ end
+
+ def test_opwrite
+ assert_call('a=[]; a*=1; a.', include: Array)
+ assert_call('a=[]; a*=""; a.', include: String)
+ assert_call('a=[1,false].sample; a||=1.0; a.', include: [Integer, Float])
+ assert_call('a=1; a&&=1.0; a.', include: Float, exclude: Integer)
+ assert_call('(a=1).b*=1; a.', include: Integer)
+ assert_call('(a=1).b||=1; a.', include: Integer)
+ assert_call('(a=1).b&&=1; a.', include: Integer)
+ assert_call('[][a=1]&&=1; a.', include: Integer)
+ assert_call('[][a=1]||=1; a.', include: Integer)
+ assert_call('[][a=1]+=1; a.', include: Integer)
+ assert_call('([1][0]+=1.0).', include: Float)
+ assert_call('([1.0][0]+=1).', include: Float)
+ assert_call('A=nil; A||=1; A.', include: Integer)
+ assert_call('A=1; A&&=1.0; A.', include: Float)
+ assert_call('A=1; A+=1.0; A.', include: Float)
+ assert_call('Array::A||=1; Array::A.', include: Integer)
+ assert_call('Array::A=1; Array::A&&=1.0; Array::A.', include: Float)
+ end
+
+ def test_case_when
+ assert_call('case x; when A; 1; when B; 1.0; end.', include: [Integer, Float, NilClass])
+ assert_call('case x; when A; 1; when B; 1.0; else; 1r; end.', include: [Integer, Float, Rational], exclude: NilClass)
+ assert_call('case; when (a=1); a.', include: Integer)
+ assert_call('case x; when (a=1); a.', include: Integer)
+ assert_call('a=1; case (a=1.0); when A; a.', include: Float, exclude: Integer)
+ assert_call('a=1; case (a=1.0); when A; end; a.', include: Float, exclude: Integer)
+ assert_call('a=1; case x; when A; a=1.0; else; a=1r; end; a.', include: [Float, Rational], exclude: Integer)
+ assert_call('a=1; case x; when A; a=1.0; when B; a=1r; end; a.', include: [Float, Rational, Integer])
+ end
+
+ def test_case_in
+ assert_call('case x; in A; 1; in B; 1.0; end.', include: [Integer, Float], exclude: NilClass)
+ assert_call('case x; in A; 1; in B; 1.0; else; 1r; end.', include: [Integer, Float, Rational], exclude: NilClass)
+ assert_call('a=""; case 1; in A; a=1; in B; a=1.0; end; a.', include: [Integer, Float], exclude: String)
+ assert_call('a=""; case 1; in A; a=1; in B; a=1.0; else; a=1r; end; a.', include: [Integer, Float, Rational], exclude: String)
+ assert_call('case 1; in x; x.', include: Integer)
+ assert_call('case x; in A if (a=1); a.', include: Integer)
+ assert_call('case x; in ^(a=1); a.', include: Integer)
+ assert_call('case x; in [1, String => a, 2]; a.', include: String)
+ assert_call('case x; in [*a, 1]; a.', include: Array)
+ assert_call('case x; in [1, *a]; a.', include: Array)
+ assert_call('case x; in [*a, 1, *b]; a.', include: Array)
+ assert_call('case x; in [*a, 1, *b]; b.', include: Array)
+ assert_call('case x; in {a: {b: **c}}; c.', include: Hash)
+ assert_call('case x; in (String | { x: Integer, y: ^$a }) => a; a.', include: [String, Hash])
+ end
+
+ def test_pattern_match
+ assert_call('1 in a; a.', include: Integer)
+ assert_call('a=1; x in String=>a; a.', include: [Integer, String])
+ assert_call('a=1; x=>String=>a; a.', include: String, exclude: Integer)
+ end
+
+ def test_bottom_type_termination
+ assert_call('a=1; tap { raise; a=1.0; a.', include: Float)
+ assert_call('a=1; tap { loop{}; a=1.0; a.', include: Float)
+ assert_call('a=1; tap { raise; a=1.0 } a.', include: Integer, exclude: Float)
+ assert_call('a=1; tap { loop{}; a=1.0 } a.', include: Integer, exclude: Float)
+ end
+
+ def test_call_parameter
+ assert_call('f((x=1),*b,c:1,**d,&e); x.', include: Integer)
+ assert_call('f(a,*(x=1),c:1,**d,&e); x.', include: Integer)
+ assert_call('f(a,*b,(x=1):1,**d,&e); x.', include: Integer)
+ assert_call('f(a,*b,c:(x=1),**d,&e); x.', include: Integer)
+ assert_call('f(a,*b,c:1,**(x=1),&e); x.', include: Integer)
+ assert_call('f(a,*b,c:1,**d,&(x=1)); x.', include: Integer)
+ assert_call('f((x=1)=>1); x.', include: Integer)
+ end
+
+ def test_block_args
+ assert_call('[1,2,3].tap{|a| a.', include: Array)
+ assert_call('[1,2,3].tap{|a,b| a.', include: Integer)
+ assert_call('[1,2,3].tap{|(a,b)| a.', include: Integer)
+ assert_call('[1,2,3].tap{|a,*b| b.', include: Array)
+ assert_call('[1,2,3].tap{|a=1.0| a.', include: [Array, Float])
+ assert_call('[1,2,3].tap{|a,**b| b.', include: Hash)
+ assert_call('1.tap{|(*),*,**| 1.', include: Integer)
+ end
+
+ def test_array_aref
+ assert_call('[1][0..].', include: [Array, NilClass], exclude: Integer)
+ assert_call('[1][0].', include: Integer, exclude: [Array, NilClass])
+ assert_call('[1].[](0).', include: Integer, exclude: [Array, NilClass])
+ assert_call('[1].[](0){}.', include: Integer, exclude: [Array, NilClass])
+ end
+ end
+end
diff --git a/test/irb/type_completion/test_type_completor.rb b/test/irb/type_completion/test_type_completor.rb
new file mode 100644
index 0000000000..eed400b3e2
--- /dev/null
+++ b/test/irb/type_completion/test_type_completor.rb
@@ -0,0 +1,181 @@
+# frozen_string_literal: true
+
+# Run test only when Ruby >= 3.0 and %w[prism rbs] are available
+return unless RUBY_VERSION >= '3.0.0'
+return if RUBY_ENGINE == 'truffleruby' # needs endless method definition
+begin
+ require 'prism'
+ require 'rbs'
+rescue LoadError
+ return
+end
+
+require 'irb/type_completion/completor'
+require_relative '../helper'
+
+module TestIRB
+ class TypeCompletorTest < TestCase
+ def setup
+ IRB::TypeCompletion::Types.load_rbs_builder unless IRB::TypeCompletion::Types.rbs_builder
+ @completor = IRB::TypeCompletion::Completor.new
+ end
+
+ def empty_binding
+ binding
+ end
+
+ TARGET_REGEXP = /(@@|@|\$)?[a-zA-Z_]*[!?=]?$/
+
+ def assert_completion(code, binding: empty_binding, include: nil, exclude: nil)
+ raise ArgumentError if include.nil? && exclude.nil?
+ target = code[TARGET_REGEXP]
+ candidates = @completor.completion_candidates(code.delete_suffix(target), target, '', bind: binding)
+ assert ([*include] - candidates).empty?, "Expected #{candidates} to include #{include}" if include
+ assert (candidates & [*exclude]).empty?, "Expected #{candidates} not to include #{exclude}" if exclude
+ end
+
+ def assert_doc_namespace(code, namespace, binding: empty_binding)
+ target = code[TARGET_REGEXP]
+ preposing = code.delete_suffix(target)
+ @completor.completion_candidates(preposing, target, '', bind: binding)
+ assert_equal namespace, @completor.doc_namespace(preposing, target, '', bind: binding)
+ end
+
+ def test_require
+ assert_completion("require '", include: 'set')
+ assert_completion("require 's", include: 'set')
+ Dir.chdir(__dir__ + "/../../..") do
+ assert_completion("require_relative 'l", include: 'lib/irb')
+ end
+ # Incomplete double quote string is InterpolatedStringNode
+ assert_completion('require "', include: 'set')
+ assert_completion('require "s', include: 'set')
+ end
+
+ def test_method_block_sym
+ assert_completion('[1].map(&:', include: 'abs')
+ assert_completion('[:a].map(&:', exclude: 'abs')
+ assert_completion('[1].map(&:a', include: 'abs')
+ assert_doc_namespace('[1].map(&:abs', 'Integer#abs')
+ end
+
+ def test_symbol
+ sym = :test_completion_symbol
+ assert_completion(":test_com", include: sym.to_s)
+ end
+
+ def test_call
+ assert_completion('1.', include: 'abs')
+ assert_completion('1.a', include: 'abs')
+ assert_completion('ran', include: 'rand')
+ assert_doc_namespace('1.abs', 'Integer#abs')
+ assert_doc_namespace('Integer.sqrt', 'Integer.sqrt')
+ assert_doc_namespace('rand', 'TestIRB::TypeCompletorTest#rand')
+ assert_doc_namespace('Object::rand', 'Object.rand')
+ end
+
+ def test_lvar
+ bind = eval('lvar = 1; binding')
+ assert_completion('lva', binding: bind, include: 'lvar')
+ assert_completion('lvar.', binding: bind, include: 'abs')
+ assert_completion('lvar.a', binding: bind, include: 'abs')
+ assert_completion('lvar = ""; lvar.', binding: bind, include: 'ascii_only?')
+ assert_completion('lvar = ""; lvar.', include: 'ascii_only?')
+ assert_doc_namespace('lvar', 'Integer', binding: bind)
+ assert_doc_namespace('lvar.abs', 'Integer#abs', binding: bind)
+ assert_doc_namespace('lvar = ""; lvar.ascii_only?', 'String#ascii_only?', binding: bind)
+ end
+
+ def test_const
+ assert_completion('Ar', include: 'Array')
+ assert_completion('::Ar', include: 'Array')
+ assert_completion('IRB::V', include: 'VERSION')
+ assert_completion('FooBar=1; F', include: 'FooBar')
+ assert_completion('::FooBar=1; ::F', include: 'FooBar')
+ assert_doc_namespace('Array', 'Array')
+ assert_doc_namespace('Array = 1; Array', 'Integer')
+ assert_doc_namespace('Object::Array', 'Array')
+ assert_completion('::', include: 'Array')
+ assert_completion('class ::', include: 'Array')
+ assert_completion('module IRB; class T', include: ['TypeCompletion', 'TracePoint'])
+ end
+
+ def test_gvar
+ assert_completion('$', include: '$stdout')
+ assert_completion('$s', include: '$stdout')
+ assert_completion('$', exclude: '$foobar')
+ assert_completion('$foobar=1; $', include: '$foobar')
+ assert_doc_namespace('$foobar=1; $foobar', 'Integer')
+ assert_doc_namespace('$stdout', 'IO')
+ assert_doc_namespace('$stdout=1; $stdout', 'Integer')
+ end
+
+ def test_ivar
+ bind = Object.new.instance_eval { @foo = 1; binding }
+ assert_completion('@', binding: bind, include: '@foo')
+ assert_completion('@f', binding: bind, include: '@foo')
+ assert_completion('@bar = 1; @', include: '@bar')
+ assert_completion('@bar = 1; @b', include: '@bar')
+ assert_doc_namespace('@bar = 1; @bar', 'Integer')
+ assert_doc_namespace('@foo', 'Integer', binding: bind)
+ assert_doc_namespace('@foo = 1.0; @foo', 'Float', binding: bind)
+ end
+
+ def test_cvar
+ bind = eval('m=Module.new; module m::M; @@foo = 1; binding; end')
+ assert_equal(1, bind.eval('@@foo'))
+ assert_completion('@', binding: bind, include: '@@foo')
+ assert_completion('@@', binding: bind, include: '@@foo')
+ assert_completion('@@f', binding: bind, include: '@@foo')
+ assert_doc_namespace('@@foo', 'Integer', binding: bind)
+ assert_doc_namespace('@@foo = 1.0; @@foo', 'Float', binding: bind)
+ assert_completion('@@bar = 1; @', include: '@@bar')
+ assert_completion('@@bar = 1; @@', include: '@@bar')
+ assert_completion('@@bar = 1; @@b', include: '@@bar')
+ assert_doc_namespace('@@bar = 1; @@bar', 'Integer')
+ end
+
+ def test_basic_object
+ bo = BasicObject.new
+ def bo.foo; end
+ bo.instance_eval { @bar = 1 }
+ bind = binding
+ bo_self_bind = bo.instance_eval { Kernel.binding }
+ assert_completion('bo.', binding: bind, include: 'foo')
+ assert_completion('def bo.baz; self.', binding: bind, include: 'foo')
+ assert_completion('[bo].first.', binding: bind, include: 'foo')
+ assert_doc_namespace('bo', 'BasicObject', binding: bind)
+ assert_doc_namespace('bo.__id__', 'BasicObject#__id__', binding: bind)
+ assert_doc_namespace('v = [bo]; v', 'Array', binding: bind)
+ assert_doc_namespace('v = [bo].first; v', 'BasicObject', binding: bind)
+ bo_self_bind = bo.instance_eval { Kernel.binding }
+ assert_completion('self.', binding: bo_self_bind, include: 'foo')
+ assert_completion('@', binding: bo_self_bind, include: '@bar')
+ assert_completion('@bar.', binding: bo_self_bind, include: 'abs')
+ assert_doc_namespace('self.__id__', 'BasicObject#__id__', binding: bo_self_bind)
+ assert_doc_namespace('@bar', 'Integer', binding: bo_self_bind)
+ if RUBY_VERSION >= '3.2.0' # Needs Class#attached_object to get instance variables from singleton class
+ assert_completion('def bo.baz; @bar.', binding: bind, include: 'abs')
+ assert_completion('def bo.baz; @', binding: bind, include: '@bar')
+ end
+ end
+
+ def test_inspect
+ rbs_builder = IRB::TypeCompletion::Types.rbs_builder
+ assert_match(/TypeCompletion::Completor\(Prism: \d.+, RBS: \d.+\)/, @completor.inspect)
+ IRB::TypeCompletion::Types.instance_variable_set(:@rbs_builder, nil)
+ assert_match(/TypeCompletion::Completor\(Prism: \d.+, RBS: loading\)/, @completor.inspect)
+ IRB::TypeCompletion::Types.instance_variable_set(:@rbs_load_error, StandardError.new('[err]'))
+ assert_match(/TypeCompletion::Completor\(Prism: \d.+, RBS: .+\[err\].+\)/, @completor.inspect)
+ ensure
+ IRB::TypeCompletion::Types.instance_variable_set(:@rbs_builder, rbs_builder)
+ IRB::TypeCompletion::Types.instance_variable_set(:@rbs_load_error, nil)
+ end
+
+ def test_none
+ candidates = @completor.completion_candidates('(', ')', '', bind: binding)
+ assert_equal [], candidates
+ assert_doc_namespace('()', nil)
+ end
+ end
+end
diff --git a/test/irb/type_completion/test_types.rb b/test/irb/type_completion/test_types.rb
new file mode 100644
index 0000000000..7698bd2fc0
--- /dev/null
+++ b/test/irb/type_completion/test_types.rb
@@ -0,0 +1,89 @@
+# frozen_string_literal: true
+
+return unless RUBY_VERSION >= '3.0.0'
+return if RUBY_ENGINE == 'truffleruby' # needs endless method definition
+
+require 'irb/type_completion/types'
+require_relative '../helper'
+
+module TestIRB
+ class TypeCompletionTypesTest < TestCase
+ def test_type_inspect
+ true_type = IRB::TypeCompletion::Types::TRUE
+ false_type = IRB::TypeCompletion::Types::FALSE
+ nil_type = IRB::TypeCompletion::Types::NIL
+ string_type = IRB::TypeCompletion::Types::STRING
+ true_or_false = IRB::TypeCompletion::Types::UnionType[true_type, false_type]
+ array_type = IRB::TypeCompletion::Types::InstanceType.new Array, { Elem: true_or_false }
+ assert_equal 'nil', nil_type.inspect
+ assert_equal 'true', true_type.inspect
+ assert_equal 'false', false_type.inspect
+ assert_equal 'String', string_type.inspect
+ assert_equal 'Array', IRB::TypeCompletion::Types::InstanceType.new(Array).inspect
+ assert_equal 'true | false', true_or_false.inspect
+ assert_equal 'Array[Elem: true | false]', array_type.inspect
+ assert_equal 'Array', array_type.inspect_without_params
+ assert_equal 'Proc', IRB::TypeCompletion::Types::PROC.inspect
+ assert_equal 'Array.itself', IRB::TypeCompletion::Types::SingletonType.new(Array).inspect
+ end
+
+ def test_type_from_object
+ obj = Object.new
+ bo = BasicObject.new
+ def bo.hash; 42; end # Needed to use this object as a hash key
+ arr = [1, 'a']
+ hash = { 'key' => :value }
+ int_type = IRB::TypeCompletion::Types.type_from_object 1
+ obj_type = IRB::TypeCompletion::Types.type_from_object obj
+ arr_type = IRB::TypeCompletion::Types.type_from_object arr
+ hash_type = IRB::TypeCompletion::Types.type_from_object hash
+ bo_type = IRB::TypeCompletion::Types.type_from_object bo
+ bo_arr_type = IRB::TypeCompletion::Types.type_from_object [bo]
+ bo_key_hash_type = IRB::TypeCompletion::Types.type_from_object({ bo => 1 })
+ bo_value_hash_type = IRB::TypeCompletion::Types.type_from_object({ x: bo })
+
+ assert_equal Integer, int_type.klass
+ # Use singleton_class to autocomplete singleton methods
+ assert_equal obj.singleton_class, obj_type.klass
+ assert_equal Object.instance_method(:singleton_class).bind_call(bo), bo_type.klass
+ # Array and Hash are special
+ assert_equal Array, arr_type.klass
+ assert_equal Array, bo_arr_type.klass
+ assert_equal Hash, hash_type.klass
+ assert_equal Hash, bo_key_hash_type.klass
+ assert_equal Hash, bo_value_hash_type.klass
+ assert_equal BasicObject, bo_arr_type.params[:Elem].klass
+ assert_equal BasicObject, bo_key_hash_type.params[:K].klass
+ assert_equal BasicObject, bo_value_hash_type.params[:V].klass
+ assert_equal 'Object', obj_type.inspect
+ assert_equal 'Array[Elem: Integer | String]', arr_type.inspect
+ assert_equal 'Hash[K: String, V: Symbol]', hash_type.inspect
+ assert_equal 'Array.itself', IRB::TypeCompletion::Types.type_from_object(Array).inspect
+ assert_equal 'IRB::TypeCompletion.itself', IRB::TypeCompletion::Types.type_from_object(IRB::TypeCompletion).inspect
+ end
+
+ def test_type_methods
+ s = +''
+ class << s
+ def foobar; end
+ private def foobaz; end
+ end
+ String.define_method(:foobarbaz) {}
+ targets = [:foobar, :foobaz, :foobarbaz]
+ type = IRB::TypeCompletion::Types.type_from_object s
+ assert_equal [:foobar, :foobarbaz], targets & type.methods
+ assert_equal [:foobar, :foobaz, :foobarbaz], targets & type.all_methods
+ assert_equal [:foobarbaz], targets & IRB::TypeCompletion::Types::STRING.methods
+ assert_equal [:foobarbaz], targets & IRB::TypeCompletion::Types::STRING.all_methods
+ ensure
+ String.remove_method :foobarbaz
+ end
+
+ def test_basic_object_methods
+ bo = BasicObject.new
+ def bo.foobar; end
+ type = IRB::TypeCompletion::Types.type_from_object bo
+ assert type.all_methods.include?(:foobar)
+ end
+ end
+end