aboutsummaryrefslogtreecommitdiffstats
path: root/lib
diff options
context:
space:
mode:
authortomoya ishida <tomoyapenguin@gmail.com>2023-11-08 11:46:24 +0900
committergit <svn-admin@ruby-lang.org>2023-11-08 02:46:33 +0000
commite34401046566ad1938b1eec654a6bf69b1319102 (patch)
tree27515a6c6e783818906480910d27fab3c6e41e97 /lib
parent7ed37388fb9c0e85325b4e3db2ffbfca3f4179ad (diff)
downloadruby-e34401046566ad1938b1eec654a6bf69b1319102.tar.gz
[ruby/irb] Type based completion using Prism and RBS
(https://github.com/ruby/irb/pull/708) * Add completor using prism and rbs * Add TypeCompletion test * Switchable completors: RegexpCompletor and TypeCompletion::Completor * Add completion info to irb_info * Complete reserved words * Fix [*] (*) {**} and prism's change of KeywordParameterNode * Fix require, frozen_string_literal * Drop prism<=0.16.0 support * Add Completor.last_completion_error for debug report * Retrieve `self` and `Module.nesting` in more safe way * Support BasicObject * Handle lvar and ivar get exception correctly * Skip ivar reference test of non-self object in ruby < 3.2 * BaseScope to RootScope, move method objects constant under Methods * Remove unused Splat struct * Drop deeply nested array/hash type calculation from actual object. Now, calculation depth is 1 * Refactor loading rbs in test, change preload_in_thread not to cache Thread object * Use new option added in prism 0.17.1 to parse code with localvars * Add Prism version check and warn when :type completor cannot be enabled * build_type_completor should skip truffleruby (because endless method definition is not supported) https://github.com/ruby/irb/commit/1048c7ed7a
Diffstat (limited to 'lib')
-rw-r--r--lib/irb.rb4
-rw-r--r--lib/irb/cmd/irb_info.rb1
-rw-r--r--lib/irb/completion.rb49
-rw-r--r--lib/irb/context.rb41
-rw-r--r--lib/irb/init.rb1
-rw-r--r--lib/irb/input-method.rb15
-rw-r--r--lib/irb/type_completion/completor.rb235
-rw-r--r--lib/irb/type_completion/methods.rb13
-rw-r--r--lib/irb/type_completion/scope.rb412
-rw-r--r--lib/irb/type_completion/type_analyzer.rb1169
-rw-r--r--lib/irb/type_completion/types.rb426
11 files changed, 2339 insertions, 27 deletions
diff --git a/lib/irb.rb b/lib/irb.rb
index d0688e6f9f..655abaf069 100644
--- a/lib/irb.rb
+++ b/lib/irb.rb
@@ -140,6 +140,10 @@ require_relative "irb/debug"
#
# IRB.conf[:USE_AUTOCOMPLETE] = false
#
+# To enable enhanced completion using type information, add the following to your +.irbrc+:
+#
+# IRB.conf[:COMPLETOR] = :type
+#
# === History
#
# By default, irb will store the last 1000 commands you used in
diff --git a/lib/irb/cmd/irb_info.rb b/lib/irb/cmd/irb_info.rb
index 75fdc38676..5b905a09bd 100644
--- a/lib/irb/cmd/irb_info.rb
+++ b/lib/irb/cmd/irb_info.rb
@@ -14,6 +14,7 @@ module IRB
str = "Ruby version: #{RUBY_VERSION}\n"
str += "IRB version: #{IRB.version}\n"
str += "InputMethod: #{IRB.CurrentContext.io.inspect}\n"
+ str += "Completion: #{IRB.CurrentContext.io.respond_to?(:completion_info) ? IRB.CurrentContext.io.completion_info : 'off'}\n"
str += ".irbrc path: #{IRB.rc_file}\n" if File.exist?(IRB.rc_file)
str += "RUBY_PLATFORM: #{RUBY_PLATFORM}\n"
str += "LANG env: #{ENV["LANG"]}\n" if ENV["LANG"] && !ENV["LANG"].empty?
diff --git a/lib/irb/completion.rb b/lib/irb/completion.rb
index 61bdc33587..e3ebe4abff 100644
--- a/lib/irb/completion.rb
+++ b/lib/irb/completion.rb
@@ -9,6 +9,30 @@ require_relative 'ruby-lex'
module IRB
class BaseCompletor # :nodoc:
+
+ # Set of reserved words used by Ruby, you should not use these for
+ # constants or variables
+ ReservedWords = %w[
+ __ENCODING__ __LINE__ __FILE__
+ BEGIN END
+ alias and
+ begin break
+ case class
+ def defined? do
+ else elsif end ensure
+ false for
+ if in
+ module
+ next nil not
+ or
+ redo rescue retry return
+ self super
+ then true
+ undef unless until
+ when while
+ yield
+ ]
+
def completion_candidates(preposing, target, postposing, bind:)
raise NotImplementedError
end
@@ -94,28 +118,9 @@ module IRB
end
}
- # Set of reserved words used by Ruby, you should not use these for
- # constants or variables
- ReservedWords = %w[
- __ENCODING__ __LINE__ __FILE__
- BEGIN END
- alias and
- begin break
- case class
- def defined? do
- else elsif end ensure
- false for
- if in
- module
- next nil not
- or
- redo rescue retry return
- self super
- then true
- undef unless until
- when while
- yield
- ]
+ def inspect
+ 'RegexpCompletor'
+ end
def complete_require_path(target, preposing, postposing)
if target =~ /\A(['"])([^'"]+)\Z/
diff --git a/lib/irb/context.rb b/lib/irb/context.rb
index a20510d73c..5dfe9d0d71 100644
--- a/lib/irb/context.rb
+++ b/lib/irb/context.rb
@@ -86,14 +86,14 @@ module IRB
when nil
if STDIN.tty? && IRB.conf[:PROMPT_MODE] != :INF_RUBY && !use_singleline?
# Both of multiline mode and singleline mode aren't specified.
- @io = RelineInputMethod.new
+ @io = RelineInputMethod.new(build_completor)
else
@io = nil
end
when false
@io = nil
when true
- @io = RelineInputMethod.new
+ @io = RelineInputMethod.new(build_completor)
end
unless @io
case use_singleline?
@@ -149,6 +149,43 @@ module IRB
@command_aliases = IRB.conf[:COMMAND_ALIASES]
end
+ private def build_completor
+ completor_type = IRB.conf[:COMPLETOR]
+ case completor_type
+ when :regexp
+ return RegexpCompletor.new
+ when :type
+ completor = build_type_completor
+ return completor if completor
+ else
+ warn "Invalid value for IRB.conf[:COMPLETOR]: #{completor_type}"
+ end
+ # Fallback to RegexpCompletor
+ RegexpCompletor.new
+ end
+
+ TYPE_COMPLETION_REQUIRED_PRISM_VERSION = '0.17.1'
+
+ private def build_type_completor
+ unless Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.0.0') && RUBY_ENGINE != 'truffleruby'
+ warn 'TypeCompletion requires RUBY_VERSION >= 3.0.0'
+ return
+ end
+ begin
+ require 'prism'
+ rescue LoadError => e
+ warn "TypeCompletion requires Prism: #{e.message}"
+ return
+ end
+ unless Gem::Version.new(Prism::VERSION) >= Gem::Version.new(TYPE_COMPLETION_REQUIRED_PRISM_VERSION)
+ warn "TypeCompletion requires Prism::VERSION >= #{TYPE_COMPLETION_REQUIRED_PRISM_VERSION}"
+ return
+ end
+ require 'irb/type_completion/completor'
+ TypeCompletion::Types.preload_in_thread
+ TypeCompletion::Completor.new
+ end
+
def save_history=(val)
IRB.conf[:SAVE_HISTORY] = val
end
diff --git a/lib/irb/init.rb b/lib/irb/init.rb
index d9549420b4..e9111974f0 100644
--- a/lib/irb/init.rb
+++ b/lib/irb/init.rb
@@ -76,6 +76,7 @@ module IRB # :nodoc:
@CONF[:USE_SINGLELINE] = false unless defined?(ReadlineInputMethod)
@CONF[:USE_COLORIZE] = (nc = ENV['NO_COLOR']).nil? || nc.empty?
@CONF[:USE_AUTOCOMPLETE] = ENV.fetch("IRB_USE_AUTOCOMPLETE", "true") != "false"
+ @CONF[:COMPLETOR] = :regexp
@CONF[:INSPECT_MODE] = true
@CONF[:USE_TRACER] = false
@CONF[:USE_LOADER] = false
diff --git a/lib/irb/input-method.rb b/lib/irb/input-method.rb
index cef65b7162..94ad28cd63 100644
--- a/lib/irb/input-method.rb
+++ b/lib/irb/input-method.rb
@@ -193,6 +193,10 @@ module IRB
}
end
+ def completion_info
+ 'RegexpCompletor'
+ end
+
# Reads the next line from this input method.
#
# See IO#gets for more information.
@@ -230,13 +234,13 @@ module IRB
HISTORY = Reline::HISTORY
include HistorySavingAbility
# Creates a new input method object using Reline
- def initialize
+ def initialize(completor)
IRB.__send__(:set_encoding, Reline.encoding_system_needs.name, override: false)
- super
+ super()
@eof = false
- @completor = RegexpCompletor.new
+ @completor = completor
Reline.basic_word_break_characters = BASIC_WORD_BREAK_CHARACTERS
Reline.completion_append_character = nil
@@ -270,6 +274,11 @@ module IRB
end
end
+ def completion_info
+ autocomplete_message = Reline.autocompletion ? 'Autocomplete' : 'Tab Complete'
+ "#{autocomplete_message}, #{@completor.inspect}"
+ end
+
def check_termination(&block)
@check_termination_proc = block
end
diff --git a/lib/irb/type_completion/completor.rb b/lib/irb/type_completion/completor.rb
new file mode 100644
index 0000000000..e893fd8adc
--- /dev/null
+++ b/lib/irb/type_completion/completor.rb
@@ -0,0 +1,235 @@
+# frozen_string_literal: true
+
+require 'prism'
+require 'irb/completion'
+require_relative 'type_analyzer'
+
+module IRB
+ module TypeCompletion
+ class Completor < BaseCompletor # :nodoc:
+ HIDDEN_METHODS = %w[Namespace TypeName] # defined by rbs, should be hidden
+
+ class << self
+ attr_accessor :last_completion_error
+ end
+
+ def inspect
+ name = 'TypeCompletion::Completor'
+ prism_info = "Prism: #{Prism::VERSION}"
+ if Types.rbs_builder
+ "#{name}(#{prism_info}, RBS: #{RBS::VERSION})"
+ elsif Types.rbs_load_error
+ "#{name}(#{prism_info}, RBS: #{Types.rbs_load_error.inspect})"
+ else
+ "#{name}(#{prism_info}, RBS: loading)"
+ end
+ end
+
+ def completion_candidates(preposing, target, _postposing, bind:)
+ @preposing = preposing
+ verbose, $VERBOSE = $VERBOSE, nil
+ code = "#{preposing}#{target}"
+ @result = analyze code, bind
+ name, candidates = candidates_from_result(@result)
+
+ all_symbols_pattern = /\A[ -\/:-@\[-`\{-~]*\z/
+ candidates.map(&:to_s).select { !_1.match?(all_symbols_pattern) && _1.start_with?(name) }.uniq.sort.map do
+ target + _1[name.size..]
+ end
+ rescue SyntaxError, StandardError => e
+ Completor.last_completion_error = e
+ handle_error(e)
+ []
+ ensure
+ $VERBOSE = verbose
+ end
+
+ def doc_namespace(preposing, matched, postposing, bind:)
+ name = matched[/[a-zA-Z_0-9]*[!?=]?\z/]
+ method_doc = -> type do
+ type = type.types.find { _1.all_methods.include? name.to_sym }
+ case type
+ when Types::SingletonType
+ "#{Types.class_name_of(type.module_or_class)}.#{name}"
+ when Types::InstanceType
+ "#{Types.class_name_of(type.klass)}##{name}"
+ end
+ end
+ call_or_const_doc = -> type do
+ if name =~ /\A[A-Z]/
+ type = type.types.grep(Types::SingletonType).find { _1.module_or_class.const_defined?(name) }
+ type.module_or_class == Object ? name : "#{Types.class_name_of(type.module_or_class)}::#{name}" if type
+ else
+ method_doc.call(type)
+ end
+ end
+
+ value_doc = -> type do
+ return unless type
+ type.types.each do |t|
+ case t
+ when Types::SingletonType
+ return Types.class_name_of(t.module_or_class)
+ when Types::InstanceType
+ return Types.class_name_of(t.klass)
+ end
+ end
+ nil
+ end
+
+ case @result
+ in [:call_or_const, type, _name, _self_call]
+ call_or_const_doc.call type
+ in [:const, type, _name, scope]
+ if type
+ call_or_const_doc.call type
+ else
+ value_doc.call scope[name]
+ end
+ in [:gvar, _name, scope]
+ value_doc.call scope["$#{name}"]
+ in [:ivar, _name, scope]
+ value_doc.call scope["@#{name}"]
+ in [:cvar, _name, scope]
+ value_doc.call scope["@@#{name}"]
+ in [:call, type, _name, _self_call]
+ method_doc.call type
+ in [:lvar_or_method, _name, scope]
+ if scope.local_variables.include?(name)
+ value_doc.call scope[name]
+ else
+ method_doc.call scope.self_type
+ end
+ else
+ end
+ end
+
+ def candidates_from_result(result)
+ candidates = case result
+ in [:require, name]
+ retrieve_files_to_require_from_load_path
+ in [:require_relative, name]
+ retrieve_files_to_require_relative_from_current_dir
+ in [:call_or_const, type, name, self_call]
+ ((self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS) | type.constants
+ in [:const, type, name, scope]
+ if type
+ scope_constants = type.types.flat_map do |t|
+ scope.table_module_constants(t.module_or_class) if t.is_a?(Types::SingletonType)
+ end
+ (scope_constants.compact | type.constants.map(&:to_s)).sort
+ else
+ scope.constants.sort | ReservedWords
+ end
+ in [:ivar, name, scope]
+ ivars = scope.instance_variables.sort
+ name == '@' ? ivars + scope.class_variables.sort : ivars
+ in [:cvar, name, scope]
+ scope.class_variables
+ in [:gvar, name, scope]
+ scope.global_variables
+ in [:symbol, name]
+ Symbol.all_symbols.map { _1.inspect[1..] }
+ in [:call, type, name, self_call]
+ (self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS
+ in [:lvar_or_method, name, scope]
+ scope.self_type.all_methods.map(&:to_s) | scope.local_variables | ReservedWords
+ else
+ []
+ end
+ [name || '', candidates]
+ end
+
+ def analyze(code, binding = Object::TOPLEVEL_BINDING)
+ # Workaround for https://github.com/ruby/prism/issues/1592
+ return if code.match?(/%[qQ]\z/)
+
+ ast = Prism.parse(code, scopes: [binding.local_variables]).value
+ name = code[/(@@|@|\$)?\w*[!?=]?\z/]
+ *parents, target_node = find_target ast, code.bytesize - name.bytesize
+ return unless target_node
+
+ calculate_scope = -> { TypeAnalyzer.calculate_target_type_scope(binding, parents, target_node).last }
+ calculate_type_scope = ->(node) { TypeAnalyzer.calculate_target_type_scope binding, [*parents, target_node], node }
+
+ case target_node
+ when Prism::StringNode, Prism::InterpolatedStringNode
+ call_node, args_node = parents.last(2)
+ return unless call_node.is_a?(Prism::CallNode) && call_node.receiver.nil?
+ return unless args_node.is_a?(Prism::ArgumentsNode) && args_node.arguments.size == 1
+
+ case call_node.name
+ when :require
+ [:require, name.rstrip]
+ when :require_relative
+ [:require_relative, name.rstrip]
+ end
+ when Prism::SymbolNode
+ if parents.last.is_a? Prism::BlockArgumentNode # method(&:target)
+ receiver_type, _scope = calculate_type_scope.call target_node
+ [:call, receiver_type, name, false]
+ else
+ [:symbol, name] unless name.empty?
+ end
+ when Prism::CallNode
+ return [:lvar_or_method, name, calculate_scope.call] if target_node.receiver.nil?
+
+ self_call = target_node.receiver.is_a? Prism::SelfNode
+ op = target_node.call_operator
+ receiver_type, _scope = calculate_type_scope.call target_node.receiver
+ receiver_type = receiver_type.nonnillable if op == '&.'
+ [op == '::' ? :call_or_const : :call, receiver_type, name, self_call]
+ when Prism::LocalVariableReadNode, Prism::LocalVariableTargetNode
+ [:lvar_or_method, name, calculate_scope.call]
+ when Prism::ConstantReadNode, Prism::ConstantTargetNode
+ if parents.last.is_a? Prism::ConstantPathNode
+ path_node = parents.last
+ if path_node.parent # A::B
+ receiver, scope = calculate_type_scope.call(path_node.parent)
+ [:const, receiver, name, scope]
+ else # ::A
+ scope = calculate_scope.call
+ [:const, Types::SingletonType.new(Object), name, scope]
+ end
+ else
+ [:const, nil, name, calculate_scope.call]
+ end
+ when Prism::GlobalVariableReadNode, Prism::GlobalVariableTargetNode
+ [:gvar, name, calculate_scope.call]
+ when Prism::InstanceVariableReadNode, Prism::InstanceVariableTargetNode
+ [:ivar, name, calculate_scope.call]
+ when Prism::ClassVariableReadNode, Prism::ClassVariableTargetNode
+ [:cvar, name, calculate_scope.call]
+ end
+ end
+
+ def find_target(node, position)
+ location = (
+ case node
+ when Prism::CallNode
+ node.message_loc
+ when Prism::SymbolNode
+ node.value_loc
+ when Prism::StringNode
+ node.content_loc
+ when Prism::InterpolatedStringNode
+ node.closing_loc if node.parts.empty?
+ end
+ )
+ return [node] if location&.start_offset == position
+
+ node.compact_child_nodes.each do |n|
+ match = find_target(n, position)
+ next unless match
+ match.unshift node
+ return match
+ end
+
+ [node] if node.location.start_offset == position
+ end
+
+ def handle_error(e)
+ end
+ end
+ end
+end
diff --git a/lib/irb/type_completion/methods.rb b/lib/irb/type_completion/methods.rb
new file mode 100644
index 0000000000..8a88b6d0f9
--- /dev/null
+++ b/lib/irb/type_completion/methods.rb
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+module IRB
+ module TypeCompletion
+ module Methods
+ OBJECT_SINGLETON_CLASS_METHOD = Object.instance_method(:singleton_class)
+ OBJECT_INSTANCE_VARIABLES_METHOD = Object.instance_method(:instance_variables)
+ OBJECT_INSTANCE_VARIABLE_GET_METHOD = Object.instance_method(:instance_variable_get)
+ OBJECT_CLASS_METHOD = Object.instance_method(:class)
+ MODULE_NAME_METHOD = Module.instance_method(:name)
+ end
+ end
+end
diff --git a/lib/irb/type_completion/scope.rb b/lib/irb/type_completion/scope.rb
new file mode 100644
index 0000000000..5a58a0ed65
--- /dev/null
+++ b/lib/irb/type_completion/scope.rb
@@ -0,0 +1,412 @@
+# frozen_string_literal: true
+
+require 'set'
+require_relative 'types'
+
+module IRB
+ module TypeCompletion
+
+ class RootScope
+ attr_reader :module_nesting, :self_object
+
+ def initialize(binding, self_object, local_variables)
+ @binding = binding
+ @self_object = self_object
+ @cache = {}
+ modules = [*binding.eval('::Module.nesting'), Object]
+ @module_nesting = modules.map { [_1, []] }
+ binding_local_variables = binding.local_variables
+ uninitialized_locals = local_variables - binding_local_variables
+ uninitialized_locals.each { @cache[_1] = Types::NIL }
+ @local_variables = (local_variables | binding_local_variables).map(&:to_s).to_set
+ @global_variables = Kernel.global_variables.map(&:to_s).to_set
+ @owned_constants_cache = {}
+ end
+
+ def level() = 0
+
+ def level_of(_name, _var_type) = 0
+
+ def mutable?() = false
+
+ def module_own_constant?(mod, name)
+ set = (@owned_constants_cache[mod] ||= Set.new(mod.constants.map(&:to_s)))
+ set.include? name
+ end
+
+ def get_const(nesting, path, _key = nil)
+ return unless nesting
+
+ result = path.reduce nesting do |mod, name|
+ return nil unless mod.is_a?(Module) && module_own_constant?(mod, name)
+ mod.const_get name
+ end
+ Types.type_from_object result
+ end
+
+ def get_cvar(nesting, path, name, _key = nil)
+ return Types::NIL unless nesting
+
+ result = path.reduce nesting do |mod, n|
+ return Types::NIL unless mod.is_a?(Module) && module_own_constant?(mod, n)
+ mod.const_get n
+ end
+ value = result.class_variable_get name if result.is_a?(Module) && name.size >= 3 && result.class_variable_defined?(name)
+ Types.type_from_object value
+ end
+
+ def [](name)
+ @cache[name] ||= (
+ value = case RootScope.type_by_name name
+ when :ivar
+ begin
+ Methods::OBJECT_INSTANCE_VARIABLE_GET_METHOD.bind_call(@self_object, name)
+ rescue NameError
+ end
+ when :lvar
+ begin
+ @binding.local_variable_get(name)
+ rescue NameError
+ end
+ when :gvar
+ @binding.eval name if @global_variables.include? name
+ end
+ Types.type_from_object(value)
+ )
+ end
+
+ def self_type
+ Types.type_from_object @self_object
+ end
+
+ def local_variables() = @local_variables.to_a
+
+ def global_variables() = @global_variables.to_a
+
+ def self.type_by_name(name)
+ if name.start_with? '@@'
+ # "@@cvar" or "@@cvar::[module_id]::[module_path]"
+ :cvar
+ elsif name.start_with? '@'
+ :ivar
+ elsif name.start_with? '$'
+ :gvar
+ elsif name.start_with? '%'
+ :internal
+ elsif name[0].downcase != name[0] || name[0].match?(/\d/)
+ # "ConstName" or "[module_id]::[const_path]"
+ :const
+ else
+ :lvar
+ end
+ end
+ end
+
+ class Scope
+ BREAK_RESULT = '%break'
+ NEXT_RESULT = '%next'
+ RETURN_RESULT = '%return'
+ PATTERNMATCH_BREAK = '%match'
+
+ attr_reader :parent, :mergeable_changes, :level, :module_nesting
+
+ def self.from_binding(binding, locals) = new(RootScope.new(binding, binding.receiver, locals))
+
+ def initialize(parent, table = {}, trace_ivar: true, trace_lvar: true, self_type: nil, nesting: nil)
+ @parent = parent
+ @level = parent.level + 1
+ @trace_ivar = trace_ivar
+ @trace_lvar = trace_lvar
+ @module_nesting = nesting ? [nesting, *parent.module_nesting] : parent.module_nesting
+ @self_type = self_type
+ @terminated = false
+ @jump_branches = []
+ @mergeable_changes = @table = table.transform_values { [level, _1] }
+ end
+
+ def mutable? = true
+
+ def terminated?
+ @terminated
+ end
+
+ def terminate_with(type, value)
+ return if terminated?
+ store_jump type, value, @mergeable_changes
+ terminate
+ end
+
+ def store_jump(type, value, changes)
+ return if terminated?
+ if has_own?(type)
+ changes[type] = [level, value]
+ @jump_branches << changes
+ elsif @parent.mutable?
+ @parent.store_jump(type, value, changes)
+ end
+ end
+
+ def terminate
+ return if terminated?
+ @terminated = true
+ @table = @mergeable_changes.dup
+ end
+
+ def trace?(name)
+ return false unless @parent
+ type = RootScope.type_by_name(name)
+ type == :ivar ? @trace_ivar : type == :lvar ? @trace_lvar : true
+ end
+
+ def level_of(name, var_type)
+ case var_type
+ when :ivar
+ return level unless @trace_ivar
+ when :gvar
+ return 0
+ end
+ variable_level, = @table[name]
+ variable_level || parent.level_of(name, var_type)
+ end
+
+ def get_const(nesting, path, key = nil)
+ key ||= [nesting.__id__, path].join('::')
+ _l, value = @table[key]
+ value || @parent.get_const(nesting, path, key)
+ end
+
+ def get_cvar(nesting, path, name, key = nil)
+ key ||= [name, nesting.__id__, path].join('::')
+ _l, value = @table[key]
+ value || @parent.get_cvar(nesting, path, name, key)
+ end
+
+ def [](name)
+ type = RootScope.type_by_name(name)
+ if type == :const
+ return get_const(nil, nil, name) || Types::NIL if name.include?('::')
+
+ module_nesting.each do |(nesting, path)|
+ value = get_const nesting, [*path, name]
+ return value if value
+ end
+ return Types::NIL
+ elsif type == :cvar
+ return get_cvar(nil, nil, nil, name) if name.include?('::')
+
+ nesting, path = module_nesting.first
+ return get_cvar(nesting, path, name)
+ end
+ level, value = @table[name]
+ if level
+ value
+ elsif trace? name
+ @parent[name]
+ elsif type == :ivar
+ self_instance_variable_get name
+ end
+ end
+
+ def set_const(nesting, path, value)
+ key = [nesting.__id__, path].join('::')
+ @table[key] = [0, value]
+ end
+
+ def set_cvar(nesting, path, name, value)
+ key = [name, nesting.__id__, path].join('::')
+ @table[key] = [0, value]
+ end
+
+ def []=(name, value)
+ type = RootScope.type_by_name(name)
+ if type == :const
+ if name.include?('::')
+ @table[name] = [0, value]
+ else
+ parent_module, parent_path = module_nesting.first
+ set_const parent_module, [*parent_path, name], value
+ end
+ return
+ elsif type == :cvar
+ if name.include?('::')
+ @table[name] = [0, value]
+ else
+ parent_module, parent_path = module_nesting.first
+ set_cvar parent_module, parent_path, name, value
+ end
+ return
+ end
+ variable_level = level_of name, type
+ @table[name] = [variable_level, value] if variable_level
+ end
+
+ def self_type
+ @self_type || @parent.self_type
+ end
+
+ def global_variables
+ gvar_keys = @table.keys.select do |name|
+ RootScope.type_by_name(name) == :gvar
+ end
+ gvar_keys | @parent.global_variables
+ end
+
+ def local_variables
+ lvar_keys = @table.keys.select do |name|
+ RootScope.type_by_name(name) == :lvar
+ end
+ lvar_keys |= @parent.local_variables if @trace_lvar
+ lvar_keys
+ end
+
+ def table_constants
+ constants = module_nesting.flat_map do |mod, path|
+ prefix = [mod.__id__, *path].join('::') + '::'
+ @table.keys.select { _1.start_with? prefix }.map { _1.delete_prefix(prefix).split('::').first }
+ end.uniq
+ constants |= @parent.table_constants if @parent.mutable?
+ constants
+ end
+
+ def table_module_constants(mod)
+ prefix = "#{mod.__id__}::"
+ constants = @table.keys.select { _1.start_with? prefix }.map { _1.delete_prefix(prefix).split('::').first }
+ constants |= @parent.table_constants if @parent.mutable?
+ constants
+ end
+
+ def base_scope
+ @parent.mutable? ? @parent.base_scope : @parent
+ end
+
+ def table_instance_variables
+ ivars = @table.keys.select { RootScope.type_by_name(_1) == :ivar }
+ ivars |= @parent.table_instance_variables if @parent.mutable? && @trace_ivar
+ ivars
+ end
+
+ def instance_variables
+ self_singleton_types = self_type.types.grep(Types::SingletonType)
+ singleton_classes = self_type.types.grep(Types::InstanceType).map(&:klass).select(&:singleton_class?)
+ base_self = base_scope.self_object
+ self_instance_variables = singleton_classes.flat_map do |singleton_class|
+ if singleton_class.respond_to? :attached_object
+ Methods::OBJECT_INSTANCE_VARIABLES_METHOD.bind_call(singleton_class.attached_object).map(&:to_s)
+ elsif singleton_class == Methods::OBJECT_SINGLETON_CLASS_METHOD.bind_call(base_self)
+ Methods::OBJECT_INSTANCE_VARIABLES_METHOD.bind_call(base_self).map(&:to_s)
+ else
+ []
+ end
+ end
+ [
+ self_singleton_types.flat_map { _1.module_or_class.instance_variables.map(&:to_s) },
+ self_instance_variables || [],
+ table_instance_variables
+ ].inject(:|)
+ end
+
+ def self_instance_variable_get(name)
+ self_objects = self_type.types.grep(Types::SingletonType).map(&:module_or_class)
+ singleton_classes = self_type.types.grep(Types::InstanceType).map(&:klass).select(&:singleton_class?)
+ base_self = base_scope.self_object
+ singleton_classes.each do |singleton_class|
+ if singleton_class.respond_to? :attached_object
+ self_objects << singleton_class.attached_object
+ elsif singleton_class == base_self.singleton_class
+ self_objects << base_self
+ end
+ end
+ types = self_objects.map do |object|
+ value = begin
+ Methods::OBJECT_INSTANCE_VARIABLE_GET_METHOD.bind_call(object, name)
+ rescue NameError
+ end
+ Types.type_from_object value
+ end
+ Types::UnionType[*types]
+ end
+
+ def table_class_variables
+ cvars = @table.keys.filter_map { _1.split('::', 2).first if RootScope.type_by_name(_1) == :cvar }
+ cvars |= @parent.table_class_variables if @parent.mutable?
+ cvars
+ end
+
+ def class_variables
+ cvars = table_class_variables
+ m, = module_nesting.first
+ cvars |= m.class_variables.map(&:to_s) if m.is_a? Module
+ cvars
+ end
+
+ def constants
+ module_nesting.flat_map do |nest,|
+ nest.constants
+ end.map(&:to_s) | table_constants
+ end
+
+ def merge_jumps
+ if terminated?
+ @terminated = false
+ @table = @mergeable_changes
+ merge @jump_branches
+ @terminated = true
+ else
+ merge [*@jump_branches, {}]
+ end
+ end
+
+ def conditional(&block)
+ run_branches(block, ->(_s) {}).first || Types::NIL
+ end
+
+ def never(&block)
+ block.call Scope.new(self, { BREAK_RESULT => nil, NEXT_RESULT => nil, PATTERNMATCH_BREAK => nil, RETURN_RESULT => nil })
+ end
+
+ def run_branches(*blocks)
+ results = []
+ branches = []
+ blocks.each do |block|
+ scope = Scope.new self
+ result = block.call scope
+ next if scope.terminated?
+ results << result
+ branches << scope.mergeable_changes
+ end
+ terminate if branches.empty?
+ merge branches
+ results
+ end
+
+ def has_own?(name)
+ @table.key? name
+ end
+
+ def update(child_scope)
+ current_level = level
+ child_scope.mergeable_changes.each do |name, (level, value)|
+ self[name] = value if level <= current_level
+ end
+ end
+
+ protected
+
+ def merge(branches)
+ current_level = level
+ merge = {}
+ branches.each do |changes|
+ changes.each do |name, (level, value)|
+ next if current_level < level
+ (merge[name] ||= []) << value
+ end
+ end
+ merge.each do |name, values|
+ values << self[name] unless values.size == branches.size
+ values.compact!
+ self[name] = Types::UnionType[*values.compact] unless values.empty?
+ end
+ end
+ end
+ end
+end
diff --git a/lib/irb/type_completion/type_analyzer.rb b/lib/irb/type_completion/type_analyzer.rb
new file mode 100644
index 0000000000..c4a41e4999
--- /dev/null
+++ b/lib/irb/type_completion/type_analyzer.rb
@@ -0,0 +1,1169 @@
+# frozen_string_literal: true
+
+require 'set'
+require_relative 'types'
+require_relative 'scope'
+require 'prism'
+
+module IRB
+ module TypeCompletion
+ class TypeAnalyzer
+ class DigTarget
+ def initialize(parents, receiver, &block)
+ @dig_ids = parents.to_h { [_1.__id__, true] }
+ @target_id = receiver.__id__
+ @block = block
+ end
+
+ def dig?(node) = @dig_ids[node.__id__]
+ def target?(node) = @target_id == node.__id__
+ def resolve(type, scope)
+ @block.call type, scope
+ end
+ end
+
+ OBJECT_METHODS = {
+ to_s: Types::STRING,
+ to_str: Types::STRING,
+ to_a: Types::ARRAY,
+ to_ary: Types::ARRAY,
+ to_h: Types::HASH,
+ to_hash: Types::HASH,
+ to_i: Types::INTEGER,
+ to_int: Types::INTEGER,
+ to_f: Types::FLOAT,
+ to_c: Types::COMPLEX,
+ to_r: Types::RATIONAL
+ }
+
+ def initialize(dig_targets)
+ @dig_targets = dig_targets
+ end
+
+ def evaluate(node, scope)
+ method = "evaluate_#{node.type}"
+ if respond_to? method
+ result = send method, node, scope
+ else
+ result = Types::NIL
+ end
+ @dig_targets.resolve result, scope if @dig_targets.target? node
+ result
+ end
+
+ def evaluate_program_node(node, scope)
+ evaluate node.statements, scope
+ end
+
+ def evaluate_statements_node(node, scope)
+ if node.body.empty?
+ Types::NIL
+ else
+ node.body.map { evaluate _1, scope }.last
+ end
+ end
+
+ def evaluate_def_node(node, scope)
+ if node.receiver
+ self_type = evaluate node.receiver, scope
+ else
+ current_self_types = scope.self_type.types
+ self_types = current_self_types.map do |type|
+ if type.is_a?(Types::SingletonType) && type.module_or_class.is_a?(Class)
+ Types::InstanceType.new type.module_or_class
+ else
+ type
+ end
+ end
+ self_type = Types::UnionType[*self_types]
+ end
+ if @dig_targets.dig?(node.body) || @dig_targets.dig?(node.parameters)
+ params_table = node.locals.to_h { [_1.to_s, Types::NIL] }
+ method_scope = Scope.new(
+ scope,
+ { **params_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil },
+ self_type: self_type,
+ trace_lvar: false,
+ trace_ivar: false
+ )
+ if node.parameters
+ # node.parameters is Prism::ParametersNode
+ assign_parameters node.parameters, method_scope, [], {}
+ end
+
+ if @dig_targets.dig?(node.body)
+ method_scope.conditional do |s|
+ evaluate node.body, s
+ end
+ end
+ method_scope.merge_jumps
+ scope.update method_scope
+ end
+ Types::SYMBOL
+ end
+
+ def evaluate_integer_node(_node, _scope) = Types::INTEGER
+
+ def evaluate_float_node(_node, _scope) = Types::FLOAT
+
+ def evaluate_rational_node(_node, _scope) = Types::RATIONAL
+
+ def evaluate_imaginary_node(_node, _scope) = Types::COMPLEX
+
+ def evaluate_string_node(_node, _scope) = Types::STRING
+
+ def evaluate_x_string_node(_node, _scope)
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_symbol_node(_node, _scope) = Types::SYMBOL
+
+ def evaluate_regular_expression_node(_node, _scope) = Types::REGEXP
+
+ def evaluate_string_concat_node(node, scope)
+ evaluate node.left, scope
+ evaluate node.right, scope
+ Types::STRING
+ end
+
+ def evaluate_interpolated_string_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::STRING
+ end
+
+ def evaluate_interpolated_x_string_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::STRING
+ end
+
+ def evaluate_interpolated_symbol_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::SYMBOL
+ end
+
+ def evaluate_interpolated_regular_expression_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::REGEXP
+ end
+
+ def evaluate_embedded_statements_node(node, scope)
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ Types::STRING
+ end
+
+ def evaluate_embedded_variable_node(node, scope)
+ evaluate node.variable, scope
+ Types::STRING
+ end
+
+ def evaluate_array_node(node, scope)
+ Types.array_of evaluate_list_splat_items(node.elements, scope)
+ end
+
+ def evaluate_hash_node(node, scope) = evaluate_hash(node, scope)
+ def evaluate_keyword_hash_node(node, scope) = evaluate_hash(node, scope)
+ def evaluate_hash(node, scope)
+ keys = []
+ values = []
+ node.elements.each do |assoc|
+ case assoc
+ when Prism::AssocNode
+ keys << evaluate(assoc.key, scope)
+ values << evaluate(assoc.value, scope)
+ when Prism::AssocSplatNode
+ next unless assoc.value # def f(**); {**}
+
+ hash = evaluate assoc.value, scope
+ unless hash.is_a?(Types::InstanceType) && hash.klass == Hash
+ hash = method_call hash, :to_hash, [], nil, nil, scope
+ end
+ if hash.is_a?(Types::InstanceType) && hash.klass == Hash
+ keys << hash.params[:K] if hash.params[:K]
+ values << hash.params[:V] if hash.params[:V]
+ end
+ end
+ end
+ if keys.empty? && values.empty?
+ Types::InstanceType.new Hash
+ else
+ Types::InstanceType.new Hash, K: Types::UnionType[*keys], V: Types::UnionType[*values]
+ end
+ end
+
+ def evaluate_parentheses_node(node, scope)
+ node.body ? evaluate(node.body, scope) : Types::NIL
+ end
+
+ def evaluate_constant_path_node(node, scope)
+ type, = evaluate_constant_node_info node, scope
+ type
+ end
+
+ def evaluate_self_node(_node, scope) = scope.self_type
+
+ def evaluate_true_node(_node, _scope) = Types::TRUE
+
+ def evaluate_false_node(_node, _scope) = Types::FALSE
+
+ def evaluate_nil_node(_node, _scope) = Types::NIL
+
+ def evaluate_source_file_node(_node, _scope) = Types::STRING
+
+ def evaluate_source_line_node(_node, _scope) = Types::INTEGER
+
+ def evaluate_source_encoding_node(_node, _scope) = Types::InstanceType.new(Encoding)
+
+ def evaluate_numbered_reference_read_node(_node, _scope)
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_back_reference_read_node(_node, _scope)
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_reference_read(node, scope)
+ scope[node.name.to_s] || Types::NIL
+ end
+ alias evaluate_constant_read_node evaluate_reference_read
+ alias evaluate_global_variable_read_node evaluate_reference_read
+ alias evaluate_local_variable_read_node evaluate_reference_read
+ alias evaluate_class_variable_read_node evaluate_reference_read
+ alias evaluate_instance_variable_read_node evaluate_reference_read
+
+
+ def evaluate_call_node(node, scope)
+ is_field_assign = node.name.match?(/[^<>=!\]]=\z/) || (node.name == :[]= && !node.call_operator)
+ receiver_type = node.receiver ? evaluate(node.receiver, scope) : scope.self_type
+ evaluate_method = lambda do |scope|
+ args_types, kwargs_types, block_sym_node, has_block = evaluate_call_node_arguments node, scope
+
+ if block_sym_node
+ block_sym = block_sym_node.value
+ if @dig_targets.target? block_sym_node
+ # method(args, &:completion_target)
+ call_block_proc = ->(block_args, _self_type) do
+ block_receiver = block_args.first || Types::OBJECT
+ @dig_targets.resolve block_receiver, scope
+ Types::OBJECT
+ end
+ else
+ call_block_proc = ->(block_args, _self_type) do
+ block_receiver, *rest = block_args
+ block_receiver ? method_call(block_receiver || Types::OBJECT, block_sym, rest, nil, nil, scope) : Types::OBJECT
+ end
+ end
+ elsif node.block.is_a? Prism::BlockNode
+ call_block_proc = ->(block_args, block_self_type) do
+ scope.conditional do |s|
+ numbered_parameters = node.block.locals.grep(/\A_[1-9]/).map(&:to_s)
+ params_table = node.block.locals.to_h { [_1.to_s, Types::NIL] }
+ table = { **params_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil }
+ block_scope = Scope.new s, table, self_type: block_self_type, trace_ivar: !block_self_type
+ # TODO kwargs
+ if node.block.parameters&.parameters
+ # node.block.parameters is Prism::BlockParametersNode
+ assign_parameters node.block.parameters.parameters, block_scope, block_args, {}
+ elsif !numbered_parameters.empty?
+ assign_numbered_parameters numbered_parameters, block_scope, block_args, {}
+ end
+ result = node.block.body ? evaluate(node.block.body, block_scope) : Types::NIL
+ block_scope.merge_jumps
+ s.update block_scope
+ nexts = block_scope[Scope::NEXT_RESULT]
+ breaks = block_scope[Scope::BREAK_RESULT]
+ if block_scope.terminated?
+ [Types::UnionType[*nexts], breaks]
+ else
+ [Types::UnionType[result, *nexts], breaks]
+ end
+ end
+ end
+ elsif has_block
+ call_block_proc = ->(_block_args, _self_type) { Types::OBJECT }
+ end
+ result = method_call receiver_type, node.name, args_types, kwargs_types, call_block_proc, scope
+ if is_field_assign
+ args_types.last || Types::NIL
+ else
+ result
+ end
+ end
+ if node.call_operator == '&.'
+ result = scope.conditional { evaluate_method.call _1 }
+ if receiver_type.nillable?
+ Types::UnionType[result, Types::NIL]
+ else
+ result
+ end
+ else
+ evaluate_method.call scope
+ end
+ end
+
+ def evaluate_and_node(node, scope) = evaluate_and_or(node, scope, and_op: true)
+ def evaluate_or_node(node, scope) = evaluate_and_or(node, scope, and_op: false)
+ def evaluate_and_or(node, scope, and_op:)
+ left = evaluate node.left, scope
+ right = scope.conditional { evaluate node.right, _1 }
+ if and_op
+ Types::UnionType[right, Types::NIL, Types::FALSE]
+ else
+ Types::UnionType[left, right]
+ end
+ end
+
+ def evaluate_call_operator_write_node(node, scope) = evaluate_call_write(node, scope, :operator, node.write_name)
+ def evaluate_call_and_write_node(node, scope) = evaluate_call_write(node, scope, :and, node.write_name)
+ def evaluate_call_or_write_node(node, scope) = evaluate_call_write(node, scope, :or, node.write_name)
+ def evaluate_index_operator_write_node(node, scope) = evaluate_call_write(node, scope, :operator, :[]=)
+ def evaluate_index_and_write_node(node, scope) = evaluate_call_write(node, scope, :and, :[]=)
+ def evaluate_index_or_write_node(node, scope) = evaluate_call_write(node, scope, :or, :[]=)
+ def evaluate_call_write(node, scope, operator, write_name)
+ receiver_type = evaluate node.receiver, scope
+ if write_name == :[]=
+ args_types, kwargs_types, block_sym_node, has_block = evaluate_call_node_arguments node, scope
+ else
+ args_types = []
+ end
+ if block_sym_node
+ block_sym = block_sym_node.value
+ call_block_proc = ->(block_args, _self_type) do
+ block_receiver, *rest = block_args
+ block_receiver ? method_call(block_receiver || Types::OBJECT, block_sym, rest, nil, nil, scope) : Types::OBJECT
+ end
+ elsif has_block
+ call_block_proc = ->(_block_args, _self_type) { Types::OBJECT }
+ end
+ method = write_name.to_s.delete_suffix('=')
+ left = method_call receiver_type, method, args_types, kwargs_types, call_block_proc, scope
+ case operator
+ when :and
+ right = scope.conditional { evaluate node.value, _1 }
+ Types::UnionType[right, Types::NIL, Types::FALSE]
+ when :or
+ right = scope.conditional { evaluate node.value, _1 }
+ Types::UnionType[left, right]
+ else
+ right = evaluate node.value, scope
+ method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ end
+ end
+
+ def evaluate_variable_operator_write(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = evaluate node.value, scope
+ scope[node.name.to_s] = method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ end
+ alias evaluate_global_variable_operator_write_node evaluate_variable_operator_write
+ alias evaluate_local_variable_operator_write_node evaluate_variable_operator_write
+ alias evaluate_class_variable_operator_write_node evaluate_variable_operator_write
+ alias evaluate_instance_variable_operator_write_node evaluate_variable_operator_write
+
+ def evaluate_variable_and_write(node, scope)
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[right, Types::NIL, Types::FALSE]
+ end
+ alias evaluate_global_variable_and_write_node evaluate_variable_and_write
+ alias evaluate_local_variable_and_write_node evaluate_variable_and_write
+ alias evaluate_class_variable_and_write_node evaluate_variable_and_write
+ alias evaluate_instance_variable_and_write_node evaluate_variable_and_write
+
+ def evaluate_variable_or_write(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[left, right]
+ end
+ alias evaluate_global_variable_or_write_node evaluate_variable_or_write
+ alias evaluate_local_variable_or_write_node evaluate_variable_or_write
+ alias evaluate_class_variable_or_write_node evaluate_variable_or_write
+ alias evaluate_instance_variable_or_write_node evaluate_variable_or_write
+
+ def evaluate_constant_operator_write_node(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = evaluate node.value, scope
+ scope[node.name.to_s] = method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ end
+
+ def evaluate_constant_and_write_node(node, scope)
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[right, Types::NIL, Types::FALSE]
+ end
+
+ def evaluate_constant_or_write_node(node, scope)
+ left = scope[node.name.to_s] || Types::OBJECT
+ right = scope.conditional { evaluate node.value, scope }
+ scope[node.name.to_s] = Types::UnionType[left, right]
+ end
+
+ def evaluate_constant_path_operator_write_node(node, scope)
+ left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope
+ right = evaluate node.value, scope
+ value = method_call left, node.operator, [right], nil, nil, scope, name_match: false
+ const_path_write receiver, name, value, scope
+ value
+ end
+
+ def evaluate_constant_path_and_write_node(node, scope)
+ _left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope
+ right = scope.conditional { evaluate node.value, scope }
+ value = Types::UnionType[right, Types::NIL, Types::FALSE]
+ const_path_write receiver, name, value, scope
+ value
+ end
+
+ def evaluate_constant_path_or_write_node(node, scope)
+ left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope
+ right = scope.conditional { evaluate node.value, scope }
+ value = Types::UnionType[left, right]
+ const_path_write receiver, name, value, scope
+ value
+ end
+
+ def evaluate_constant_path_write_node(node, scope)
+ receiver = evaluate node.target.parent, scope if node.target.parent
+ value = evaluate node.value, scope
+ const_path_write receiver, node.target.child.name.to_s, value, scope
+ value
+ end
+
+ def evaluate_lambda_node(node, scope)
+ local_table = node.locals.to_h { [_1.to_s, Types::OBJECT] }
+ block_scope = Scope.new scope, { **local_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil }
+ block_scope.conditional do |s|
+ assign_parameters node.parameters.parameters, s, [], {} if node.parameters&.parameters
+ evaluate node.body, s if node.body
+ end
+ block_scope.merge_jumps
+ scope.update block_scope
+ Types::PROC
+ end
+
+ def evaluate_reference_write(node, scope)
+ scope[node.name.to_s] = evaluate node.value, scope
+ end
+ alias evaluate_constant_write_node evaluate_reference_write
+ alias evaluate_global_variable_write_node evaluate_reference_write
+ alias evaluate_local_variable_write_node evaluate_reference_write
+ alias evaluate_class_variable_write_node evaluate_reference_write
+ alias evaluate_instance_variable_write_node evaluate_reference_write
+
+ def evaluate_multi_write_node(node, scope)
+ evaluated_receivers = {}
+ evaluate_multi_write_receiver node, scope, evaluated_receivers
+ value = (
+ if node.value.is_a? Prism::ArrayNode
+ if node.value.elements.any?(Prism::SplatNode)
+ evaluate node.value, scope
+ else
+ node.value.elements.map do |n|
+ evaluate n, scope
+ end
+ end
+ elsif node.value
+ evaluate node.value, scope
+ else
+ Types::NIL
+ end
+ )
+ evaluate_multi_write node, value, scope, evaluated_receivers
+ value.is_a?(Array) ? Types.array_of(*value) : value
+ end
+
+ def evaluate_if_node(node, scope) = evaluate_if_unless(node, scope)
+ def evaluate_unless_node(node, scope) = evaluate_if_unless(node, scope)
+ def evaluate_if_unless(node, scope)
+ evaluate node.predicate, scope
+ Types::UnionType[*scope.run_branches(
+ -> { node.statements ? evaluate(node.statements, _1) : Types::NIL },
+ -> { node.consequent ? evaluate(node.consequent, _1) : Types::NIL }
+ )]
+ end
+
+ def evaluate_else_node(node, scope)
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ end
+
+ def evaluate_while_until(node, scope)
+ inner_scope = Scope.new scope, { Scope::BREAK_RESULT => nil }
+ evaluate node.predicate, inner_scope
+ if node.statements
+ inner_scope.conditional do |s|
+ evaluate node.statements, s
+ end
+ end
+ inner_scope.merge_jumps
+ scope.update inner_scope
+ breaks = inner_scope[Scope::BREAK_RESULT]
+ breaks ? Types::UnionType[breaks, Types::NIL] : Types::NIL
+ end
+ alias evaluate_while_node evaluate_while_until
+ alias evaluate_until_node evaluate_while_until
+
+ def evaluate_break_node(node, scope) = evaluate_jump(node, scope, :break)
+ def evaluate_next_node(node, scope) = evaluate_jump(node, scope, :next)
+ def evaluate_return_node(node, scope) = evaluate_jump(node, scope, :return)
+ def evaluate_jump(node, scope, mode)
+ internal_key = (
+ case mode
+ when :break
+ Scope::BREAK_RESULT
+ when :next
+ Scope::NEXT_RESULT
+ when :return
+ Scope::RETURN_RESULT
+ end
+ )
+ jump_value = (
+ arguments = node.arguments&.arguments
+ if arguments.nil? || arguments.empty?
+ Types::NIL
+ elsif arguments.size == 1 && !arguments.first.is_a?(Prism::SplatNode)
+ evaluate arguments.first, scope
+ else
+ Types.array_of evaluate_list_splat_items(arguments, scope)
+ end
+ )
+ scope.terminate_with internal_key, jump_value
+ Types::NIL
+ end
+
+ def evaluate_yield_node(node, scope)
+ evaluate_list_splat_items node.arguments.arguments, scope if node.arguments
+ Types::OBJECT
+ end
+
+ def evaluate_redo_node(_node, scope)
+ scope.terminate
+ Types::NIL
+ end
+
+ def evaluate_retry_node(_node, scope)
+ scope.terminate
+ Types::NIL
+ end
+
+ def evaluate_forwarding_super_node(_node, _scope) = Types::OBJECT
+
+ def evaluate_super_node(node, scope)
+ evaluate_list_splat_items node.arguments.arguments, scope if node.arguments
+ Types::OBJECT
+ end
+
+ def evaluate_begin_node(node, scope)
+ return_type = node.statements ? evaluate(node.statements, scope) : Types::NIL
+ if node.rescue_clause
+ if node.else_clause
+ return_types = scope.run_branches(
+ ->{ evaluate node.rescue_clause, _1 },
+ ->{ evaluate node.else_clause, _1 }
+ )
+ else
+ return_types = [
+ return_type,
+ scope.conditional { evaluate node.rescue_clause, _1 }
+ ]
+ end
+ return_type = Types::UnionType[*return_types]
+ end
+ if node.ensure_clause&.statements
+ # ensure_clause is Prism::EnsureNode
+ evaluate node.ensure_clause.statements, scope
+ end
+ return_type
+ end
+
+ def evaluate_rescue_node(node, scope)
+ run_rescue = lambda do |s|
+ if node.reference
+ error_classes_type = evaluate_list_splat_items node.exceptions, s
+ error_types = error_classes_type.types.filter_map do
+ Types::InstanceType.new _1.module_or_class if _1.is_a?(Types::SingletonType)
+ end
+ error_types << Types::InstanceType.new(StandardError) if error_types.empty?
+ error_type = Types::UnionType[*error_types]
+ case node.reference
+ when Prism::LocalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::ConstantTargetNode
+ s[node.reference.name.to_s] = error_type
+ when Prism::CallNode
+ evaluate node.reference, s
+ end
+ end
+ node.statements ? evaluate(node.statements, s) : Types::NIL
+ end
+ if node.consequent # begin; rescue A; rescue B; end
+ types = scope.run_branches(
+ run_rescue,
+ -> { evaluate node.consequent, _1 }
+ )
+ Types::UnionType[*types]
+ else
+ run_rescue.call scope
+ end
+ end
+
+ def evaluate_rescue_modifier_node(node, scope)
+ a = evaluate node.expression, scope
+ b = scope.conditional { evaluate node.rescue_expression, _1 }
+ Types::UnionType[a, b]
+ end
+
+ def evaluate_singleton_class_node(node, scope)
+ klass_types = evaluate(node.expression, scope).types.filter_map do |type|
+ Types::SingletonType.new type.klass if type.is_a? Types::InstanceType
+ end
+ klass_types = [Types::CLASS] if klass_types.empty?
+ table = node.locals.to_h { [_1.to_s, Types::NIL] }
+ sclass_scope = Scope.new(
+ scope,
+ { **table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil },
+ trace_ivar: false,
+ trace_lvar: false,
+ self_type: Types::UnionType[*klass_types]
+ )
+ result = node.body ? evaluate(node.body, sclass_scope) : Types::NIL
+ scope.update sclass_scope
+ result
+ end
+
+ def evaluate_class_node(node, scope) = evaluate_class_module(node, scope, true)
+ def evaluate_module_node(node, scope) = evaluate_class_module(node, scope, false)
+ def evaluate_class_module(node, scope, is_class)
+ unless node.constant_path.is_a?(Prism::ConstantReadNode) || node.constant_path.is_a?(Prism::ConstantPathNode)
+ # Incomplete class/module `class (statement[cursor_here])::Name; end`
+ evaluate node.constant_path, scope
+ return Types::NIL
+ end
+ const_type, _receiver, parent_module, name = evaluate_constant_node_info node.constant_path, scope
+ if is_class
+ select_class_type = -> { _1.is_a?(Types::SingletonType) && _1.module_or_class.is_a?(Class) }
+ module_types = const_type.types.select(&select_class_type)
+ module_types += evaluate(node.superclass, scope).types.select(&select_class_type) if node.superclass
+ module_types << Types::CLASS if module_types.empty?
+ else
+ module_types = const_type.types.select { _1.is_a?(Types::SingletonType) && !_1.module_or_class.is_a?(Class) }
+ module_types << Types::MODULE if module_types.empty?
+ end
+ return Types::NIL unless node.body
+
+ table = node.locals.to_h { [_1.to_s, Types::NIL] }
+ if !name.empty? && (parent_module.is_a?(Module) || parent_module.nil?)
+ value = parent_module.const_get name if parent_module&.const_defined? name
+ unless value
+ value_type = scope[name]
+ value = value_type.module_or_class if value_type.is_a? Types::SingletonType
+ end
+
+ if value.is_a? Module
+ nesting = [value, []]
+ else
+ if parent_module
+ nesting = [parent_module, [name]]
+ else
+ parent_nesting, parent_path = scope.module_nesting.first
+ nesting = [parent_nesting, parent_path + [name]]
+ end
+ nesting_key = [nesting[0].__id__, nesting[1]].join('::')
+ nesting_value = is_class ? Types::CLASS : Types::MODULE
+ end
+ else
+ # parent_module == :unknown
+ # TODO: dummy module
+ end
+ module_scope = Scope.new(
+ scope,
+ { **table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil },
+ trace_ivar: false,
+ trace_lvar: false,
+ self_type: Types::UnionType[*module_types],
+ nesting: nesting
+ )
+ module_scope[nesting_key] = nesting_value if nesting_value
+ result = evaluate(node.body, module_scope)
+ scope.update module_scope
+ result
+ end
+
+ def evaluate_for_node(node, scope)
+ node.statements
+ collection = evaluate node.collection, scope
+ inner_scope = Scope.new scope, { Scope::BREAK_RESULT => nil }
+ ary_type = method_call collection, :to_ary, [], nil, nil, nil, name_match: false
+ element_types = ary_type.types.filter_map do |ary|
+ ary.params[:Elem] if ary.is_a?(Types::InstanceType) && ary.klass == Array
+ end
+ element_type = Types::UnionType[*element_types]
+ inner_scope.conditional do |s|
+ evaluate_write node.index, element_type, s, nil
+ evaluate node.statements, s if node.statements
+ end
+ inner_scope.merge_jumps
+ scope.update inner_scope
+ breaks = inner_scope[Scope::BREAK_RESULT]
+ breaks ? Types::UnionType[breaks, collection] : collection
+ end
+
+ def evaluate_case_node(node, scope)
+ target = evaluate(node.predicate, scope) if node.predicate
+ # TODO
+ branches = node.conditions.map do |condition|
+ ->(s) { evaluate_case_match target, condition, s }
+ end
+ if node.consequent
+ branches << ->(s) { evaluate node.consequent, s }
+ elsif node.conditions.any? { _1.is_a? Prism::WhenNode }
+ branches << ->(_s) { Types::NIL }
+ end
+ Types::UnionType[*scope.run_branches(*branches)]
+ end
+
+ def evaluate_match_required_node(node, scope)
+ value_type = evaluate node.value, scope
+ evaluate_match_pattern value_type, node.pattern, scope
+ Types::NIL # void value
+ end
+
+ def evaluate_match_predicate_node(node, scope)
+ value_type = evaluate node.value, scope
+ scope.conditional { evaluate_match_pattern value_type, node.pattern, _1 }
+ Types::BOOLEAN
+ end
+
+ def evaluate_range_node(node, scope)
+ beg_type = evaluate node.left, scope if node.left
+ end_type = evaluate node.right, scope if node.right
+ elem = (Types::UnionType[*[beg_type, end_type].compact]).nonnillable
+ Types::InstanceType.new Range, Elem: elem
+ end
+
+ def evaluate_defined_node(node, scope)
+ scope.conditional { evaluate node.value, _1 }
+ Types::UnionType[Types::STRING, Types::NIL]
+ end
+
+ def evaluate_flip_flop_node(node, scope)
+ scope.conditional { evaluate node.left, _1 } if node.left
+ scope.conditional { evaluate node.right, _1 } if node.right
+ Types::BOOLEAN
+ end
+
+ def evaluate_multi_target_node(node, scope)
+ # Raw MultiTargetNode, incomplete code like `a,b`, `*a`.
+ evaluate_multi_write_receiver node, scope, nil
+ Types::NIL
+ end
+
+ def evaluate_splat_node(node, scope)
+ # Raw SplatNode, incomplete code like `*a.`
+ evaluate_multi_write_receiver node.expression, scope, nil if node.expression
+ Types::NIL
+ end
+
+ def evaluate_implicit_node(node, scope)
+ evaluate node.value, scope
+ end
+
+ def evaluate_match_write_node(node, scope)
+ # /(?<a>)(?<b>)/ =~ string
+ evaluate node.call, scope
+ node.locals.each { scope[_1.to_s] = Types::UnionType[Types::STRING, Types::NIL] }
+ Types::BOOLEAN
+ end
+
+ def evaluate_match_last_line_node(_node, _scope)
+ Types::BOOLEAN
+ end
+
+ def evaluate_interpolated_match_last_line_node(node, scope)
+ node.parts.each { evaluate _1, scope }
+ Types::BOOLEAN
+ end
+
+ def evaluate_pre_execution_node(node, scope)
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ end
+
+ def evaluate_post_execution_node(node, scope)
+ node.statements && @dig_targets.dig?(node.statements) ? evaluate(node.statements, scope) : Types::NIL
+ end
+
+ def evaluate_alias_method_node(_node, _scope) = Types::NIL
+ def evaluate_alias_global_variable_node(_node, _scope) = Types::NIL
+ def evaluate_undef_node(_node, _scope) = Types::NIL
+ def evaluate_missing_node(_node, _scope) = Types::NIL
+
+ def evaluate_call_node_arguments(call_node, scope)
+ # call_node.arguments is Prism::ArgumentsNode
+ arguments = call_node.arguments&.arguments&.dup || []
+ block_arg = call_node.block.expression if call_node.block.is_a?(Prism::BlockArgumentNode)
+ kwargs = arguments.pop.elements if arguments.last.is_a?(Prism::KeywordHashNode)
+ args_types = arguments.map do |arg|
+ case arg
+ when Prism::ForwardingArgumentsNode
+ # `f(a, ...)` treat like splat
+ nil
+ when Prism::SplatNode
+ evaluate arg.expression, scope if arg.expression
+ nil # TODO: splat
+ else
+ evaluate arg, scope
+ end
+ end
+ if kwargs
+ kwargs_types = kwargs.map do |arg|
+ case arg
+ when Prism::AssocNode
+ if arg.key.is_a?(Prism::SymbolNode)
+ [arg.key.value, evaluate(arg.value, scope)]
+ else
+ evaluate arg.key, scope
+ evaluate arg.value, scope
+ nil
+ end
+ when Prism::AssocSplatNode
+ evaluate arg.value, scope if arg.value
+ nil
+ end
+ end.compact.to_h
+ end
+ if block_arg.is_a? Prism::SymbolNode
+ block_sym_node = block_arg
+ elsif block_arg
+ evaluate block_arg, scope
+ end
+ [args_types, kwargs_types, block_sym_node, !!block_arg]
+ end
+
+ def const_path_write(receiver, name, value, scope)
+ if receiver # receiver::A = value
+ singleton_type = receiver.types.find { _1.is_a? Types::SingletonType }
+ scope.set_const singleton_type.module_or_class, name, value if singleton_type
+ else # ::A = value
+ scope.set_const Object, name, value
+ end
+ end
+
+ def assign_required_parameter(node, value, scope)
+ case node
+ when Prism::RequiredParameterNode
+ scope[node.name.to_s] = value || Types::OBJECT
+ when Prism::MultiTargetNode
+ parameters = [*node.lefts, *node.rest, *node.rights]
+ values = value ? sized_splat(value, :to_ary, parameters.size) : []
+ parameters.zip values do |n, v|
+ assign_required_parameter n, v, scope
+ end
+ when Prism::SplatNode
+ splat_value = value ? Types.array_of(value) : Types::ARRAY
+ assign_required_parameter node.expression, splat_value, scope if node.expression
+ end
+ end
+
+ def evaluate_constant_node_info(node, scope)
+ case node
+ when Prism::ConstantPathNode
+ name = node.child.name.to_s
+ if node.parent
+ receiver = evaluate node.parent, scope
+ if receiver.is_a? Types::SingletonType
+ parent_module = receiver.module_or_class
+ end
+ else
+ parent_module = Object
+ end
+ if parent_module
+ type = scope.get_const(parent_module, [name]) || Types::NIL
+ else
+ parent_module = :unknown
+ type = Types::NIL
+ end
+ when Prism::ConstantReadNode
+ name = node.name.to_s
+ type = scope[name]
+ end
+ @dig_targets.resolve type, scope if @dig_targets.target? node
+ [type, receiver, parent_module, name]
+ end
+
+
+ def assign_parameters(node, scope, args, kwargs)
+ args = args.dup
+ kwargs = kwargs.dup
+ size = node.requireds.size + node.optionals.size + (node.rest ? 1 : 0) + node.posts.size
+ args = sized_splat(args.first, :to_ary, size) if size >= 2 && args.size == 1
+ reqs = args.shift node.requireds.size
+ if node.rest
+ # node.rest is Prism::RestParameterNode
+ posts = []
+ opts = args.shift node.optionals.size
+ rest = args
+ else
+ posts = args.pop node.posts.size
+ opts = args
+ rest = []
+ end
+ node.requireds.zip reqs do |n, v|
+ assign_required_parameter n, v, scope
+ end
+ node.optionals.zip opts do |n, v|
+ # n is Prism::OptionalParameterNode
+ values = [v]
+ values << evaluate(n.value, scope) if n.value
+ scope[n.name.to_s] = Types::UnionType[*values.compact]
+ end
+ node.posts.zip posts do |n, v|
+ assign_required_parameter n, v, scope
+ end
+ if node.rest&.name
+ # node.rest is Prism::RestParameterNode
+ scope[node.rest.name.to_s] = Types.array_of(*rest)
+ end
+ node.keywords.each do |n|
+ name = n.name.to_s.delete(':')
+ values = [kwargs.delete(name)]
+ # n is Prism::OptionalKeywordParameterNode (has n.value) or Prism::RequiredKeywordParameterNode (does not have n.value)
+ values << evaluate(n.value, scope) if n.respond_to?(:value)
+ scope[name] = Types::UnionType[*values.compact]
+ end
+ # node.keyword_rest is Prism::KeywordRestParameterNode or Prism::ForwardingParameterNode or Prism::NoKeywordsParameterNode
+ if node.keyword_rest.is_a?(Prism::KeywordRestParameterNode) && node.keyword_rest.name
+ scope[node.keyword_rest.name.to_s] = Types::InstanceType.new(Hash, K: Types::SYMBOL, V: Types::UnionType[*kwargs.values])
+ end
+ if node.block&.name
+ # node.block is Prism::BlockParameterNode
+ scope[node.block.name.to_s] = Types::PROC
+ end
+ end
+
+ def assign_numbered_parameters(numbered_parameters, scope, args, _kwargs)
+ return if numbered_parameters.empty?
+ max_num = numbered_parameters.map { _1[1].to_i }.max
+ if max_num == 1
+ scope['_1'] = args.first || Types::NIL
+ else
+ args = sized_splat(args.first, :to_ary, max_num) if args.size == 1
+ numbered_parameters.each do |name|
+ index = name[1].to_i - 1
+ scope[name] = args[index] || Types::NIL
+ end
+ end
+ end
+
+ def evaluate_case_match(target, node, scope)
+ case node
+ when Prism::WhenNode
+ node.conditions.each { evaluate _1, scope }
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ when Prism::InNode
+ pattern = node.pattern
+ if pattern.is_a?(Prism::IfNode) || pattern.is_a?(Prism::UnlessNode)
+ cond_node = pattern.predicate
+ pattern = pattern.statements.body.first
+ end
+ evaluate_match_pattern(target, pattern, scope)
+ evaluate cond_node, scope if cond_node # TODO: conditional branch
+ node.statements ? evaluate(node.statements, scope) : Types::NIL
+ end
+ end
+
+ def evaluate_match_pattern(value, pattern, scope)
+ # TODO: scope.terminate_with Scope::PATTERNMATCH_BREAK, Types::NIL
+ case pattern
+ when Prism::FindPatternNode
+ # TODO
+ evaluate_match_pattern Types::OBJECT, pattern.left, scope
+ pattern.requireds.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ evaluate_match_pattern Types::OBJECT, pattern.right, scope
+ when Prism::ArrayPatternNode
+ # TODO
+ pattern.requireds.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ evaluate_match_pattern Types::OBJECT, pattern.rest, scope if pattern.rest
+ pattern.posts.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ Types::ARRAY
+ when Prism::HashPatternNode
+ # TODO
+ pattern.elements.each { evaluate_match_pattern Types::OBJECT, _1, scope }
+ if pattern.respond_to?(:rest) && pattern.rest
+ evaluate_match_pattern Types::OBJECT, pattern.rest, scope
+ end
+ Types::HASH
+ when Prism::AssocNode
+ evaluate_match_pattern value, pattern.value, scope if pattern.value
+ Types::OBJECT
+ when Prism::AssocSplatNode
+ # TODO
+ evaluate_match_pattern Types::HASH, pattern.value, scope
+ Types::OBJECT
+ when Prism::PinnedVariableNode
+ evaluate pattern.variable, scope
+ when Prism::PinnedExpressionNode
+ evaluate pattern.expression, scope
+ when Prism::LocalVariableTargetNode
+ scope[pattern.name.to_s] = value
+ when Prism::AlternationPatternNode
+ Types::UnionType[evaluate_match_pattern(value, pattern.left, scope), evaluate_match_pattern(value, pattern.right, scope)]
+ when Prism::CapturePatternNode
+ capture_type = class_or_value_to_instance evaluate_match_pattern(value, pattern.value, scope)
+ value = capture_type unless capture_type.types.empty? || capture_type.types == [Types::OBJECT]
+ evaluate_match_pattern value, pattern.target, scope
+ when Prism::SplatNode
+ value = Types.array_of value
+ evaluate_match_pattern value, pattern.expression, scope if pattern.expression
+ value
+ else
+ # literal node
+ type = evaluate(pattern, scope)
+ class_or_value_to_instance(type)
+ end
+ end
+
+ def class_or_value_to_instance(type)
+ instance_types = type.types.map do |t|
+ t.is_a?(Types::SingletonType) ? Types::InstanceType.new(t.module_or_class) : t
+ end
+ Types::UnionType[*instance_types]
+ end
+
+ def evaluate_write(node, value, scope, evaluated_receivers)
+ case node
+ when Prism::MultiTargetNode
+ evaluate_multi_write node, value, scope, evaluated_receivers
+ when Prism::CallNode
+ evaluated_receivers&.[](node.receiver) || evaluate(node.receiver, scope) if node.receiver
+ when Prism::SplatNode
+ evaluate_write node.expression, Types.array_of(value), scope, evaluated_receivers if node.expression
+ when Prism::LocalVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::ConstantTargetNode
+ scope[node.name.to_s] = value
+ when Prism::ConstantPathTargetNode
+ receiver = evaluated_receivers&.[](node.parent) || evaluate(node.parent, scope) if node.parent
+ const_path_write receiver, node.child.name.to_s, value, scope
+ value
+ end
+ end
+
+ def evaluate_multi_write(node, values, scope, evaluated_receivers)
+ pre_targets = node.lefts
+ splat_target = node.rest
+ post_targets = node.rights
+ size = pre_targets.size + (splat_target ? 1 : 0) + post_targets.size
+ values = values.is_a?(Array) ? values.dup : sized_splat(values, :to_ary, size)
+ pre_pairs = pre_targets.zip(values.shift(pre_targets.size))
+ post_pairs = post_targets.zip(values.pop(post_targets.size))
+ splat_pairs = splat_target ? [[splat_target, Types::UnionType[*values]]] : []
+ (pre_pairs + splat_pairs + post_pairs).each do |target, value|
+ evaluate_write target, value || Types::NIL, scope, evaluated_receivers
+ end
+ end
+
+ def evaluate_multi_write_receiver(node, scope, evaluated_receivers)
+ case node
+ when Prism::MultiWriteNode, Prism::MultiTargetNode
+ targets = [*node.lefts, *node.rest, *node.rights]
+ targets.each { evaluate_multi_write_receiver _1, scope, evaluated_receivers }
+ when Prism::CallNode
+ if node.receiver
+ receiver = evaluate(node.receiver, scope)
+ evaluated_receivers[node.receiver] = receiver if evaluated_receivers
+ end
+ if node.arguments
+ node.arguments.arguments&.each do |arg|
+ if arg.is_a? Prism::SplatNode
+ evaluate arg.expression, scope
+ else
+ evaluate arg, scope
+ end
+ end
+ end
+ when Prism::SplatNode
+ evaluate_multi_write_receiver node.expression, scope, evaluated_receivers if node.expression
+ end
+ end
+
+ def evaluate_list_splat_items(list, scope)
+ items = list.flat_map do |node|
+ if node.is_a? Prism::SplatNode
+ next unless node.expression # def f(*); [*]
+
+ splat = evaluate node.expression, scope
+ array_elem, non_array = partition_to_array splat.nonnillable, :to_a
+ [*array_elem, *non_array]
+ else
+ evaluate node, scope
+ end
+ end.compact.uniq
+ Types::UnionType[*items]
+ end
+
+ def sized_splat(value, method, size)
+ array_elem, non_array = partition_to_array value, method
+ values = [Types::UnionType[*array_elem, *non_array]]
+ values += [array_elem] * (size - 1) if array_elem && size >= 1
+ values
+ end
+
+ def partition_to_array(value, method)
+ arrays, non_arrays = value.types.partition { _1.is_a?(Types::InstanceType) && _1.klass == Array }
+ non_arrays.select! do |type|
+ to_array_result = method_call type, method, [], nil, nil, nil, name_match: false
+ if to_array_result.is_a?(Types::InstanceType) && to_array_result.klass == Array
+ arrays << to_array_result
+ false
+ else
+ true
+ end
+ end
+ array_elem = arrays.empty? ? nil : Types::UnionType[*arrays.map { _1.params[:Elem] || Types::OBJECT }]
+ non_array = non_arrays.empty? ? nil : Types::UnionType[*non_arrays]
+ [array_elem, non_array]
+ end
+
+ def method_call(receiver, method_name, args, kwargs, block, scope, name_match: true)
+ methods = Types.rbs_methods receiver, method_name.to_sym, args, kwargs, !!block
+ block_called = false
+ type_breaks = methods.map do |method, given_params, method_params|
+ receiver_vars = receiver.is_a?(Types::InstanceType) ? receiver.params : {}
+ free_vars = method.type.free_variables - receiver_vars.keys.to_set
+ vars = receiver_vars.merge Types.match_free_variables(free_vars, method_params, given_params)
+ if block && method.block
+ params_type = method.block.type.required_positionals.map do |func_param|
+ Types.from_rbs_type func_param.type, receiver, vars
+ end
+ self_type = Types.from_rbs_type method.block.self_type, receiver, vars if method.block.self_type
+ block_response, breaks = block.call params_type, self_type
+ block_called = true
+ vars.merge! Types.match_free_variables(free_vars - vars.keys.to_set, [method.block.type.return_type], [block_response])
+ end
+ if Types.method_return_bottom?(method)
+ [nil, breaks]
+ else
+ [Types.from_rbs_type(method.type.return_type, receiver, vars || {}), breaks]
+ end
+ end
+ block&.call [], nil unless block_called
+ terminates = !type_breaks.empty? && type_breaks.map(&:first).all?(&:nil?)
+ types = type_breaks.map(&:first).compact
+ breaks = type_breaks.map(&:last).compact
+ types << OBJECT_METHODS[method_name.to_sym] if name_match && OBJECT_METHODS.has_key?(method_name.to_sym)
+
+ if method_name.to_sym == :new
+ receiver.types.each do |type|
+ if type.is_a?(Types::SingletonType) && type.module_or_class.is_a?(Class)
+ types << Types::InstanceType.new(type.module_or_class)
+ end
+ end
+ end
+ scope&.terminate if terminates && breaks.empty?
+ Types::UnionType[*types, *breaks]
+ end
+
+ def self.calculate_target_type_scope(binding, parents, target)
+ dig_targets = DigTarget.new(parents, target) do |type, scope|
+ return type, scope
+ end
+ program = parents.first
+ scope = Scope.from_binding(binding, program.locals)
+ new(dig_targets).evaluate program, scope
+ [Types::NIL, scope]
+ end
+ end
+ end
+end
diff --git a/lib/irb/type_completion/types.rb b/lib/irb/type_completion/types.rb
new file mode 100644
index 0000000000..f0f2342ffe
--- /dev/null
+++ b/lib/irb/type_completion/types.rb
@@ -0,0 +1,426 @@
+# frozen_string_literal: true
+
+require_relative 'methods'
+
+module IRB
+ module TypeCompletion
+ module Types
+ OBJECT_TO_TYPE_SAMPLE_SIZE = 50
+
+ singleton_class.attr_reader :rbs_builder, :rbs_load_error
+
+ def self.preload_in_thread
+ return if @preload_started
+
+ @preload_started = true
+ Thread.new do
+ load_rbs_builder
+ end
+ end
+
+ def self.load_rbs_builder
+ require 'rbs'
+ require 'rbs/cli'
+ loader = RBS::CLI::LibraryOptions.new.loader
+ loader.add path: Pathname('sig')
+ @rbs_builder = RBS::DefinitionBuilder.new env: RBS::Environment.from_loader(loader).resolve_type_names
+ rescue LoadError, StandardError => e
+ @rbs_load_error = e
+ nil
+ end
+
+ def self.class_name_of(klass)
+ klass = klass.superclass if klass.singleton_class?
+ Methods::MODULE_NAME_METHOD.bind_call klass
+ end
+
+ def self.rbs_search_method(klass, method_name, singleton)
+ klass.ancestors.each do |ancestor|
+ name = class_name_of ancestor
+ next unless name && rbs_builder
+ type_name = RBS::TypeName(name).absolute!
+ definition = (singleton ? rbs_builder.build_singleton(type_name) : rbs_builder.build_instance(type_name)) rescue nil
+ method = definition.methods[method_name] if definition
+ return method if method
+ end
+ nil
+ end
+
+ def self.method_return_type(type, method_name)
+ receivers = type.types.map do |t|
+ case t
+ in SingletonType
+ [t, t.module_or_class, true]
+ in InstanceType
+ [t, t.klass, false]
+ end
+ end
+ types = receivers.flat_map do |receiver_type, klass, singleton|
+ method = rbs_search_method klass, method_name, singleton
+ next [] unless method
+ method.method_types.map do |method|
+ from_rbs_type(method.type.return_type, receiver_type, {})
+ end
+ end
+ UnionType[*types]
+ end
+
+ def self.rbs_methods(type, method_name, args_types, kwargs_type, has_block)
+ return [] unless rbs_builder
+
+ receivers = type.types.map do |t|
+ case t
+ in SingletonType
+ [t, t.module_or_class, true]
+ in InstanceType
+ [t, t.klass, false]
+ end
+ end
+ has_splat = args_types.include?(nil)
+ methods_with_score = receivers.flat_map do |receiver_type, klass, singleton|
+ method = rbs_search_method klass, method_name, singleton
+ next [] unless method
+ method.method_types.map do |method_type|
+ score = 0
+ score += 2 if !!method_type.block == has_block
+ reqs = method_type.type.required_positionals
+ opts = method_type.type.optional_positionals
+ rest = method_type.type.rest_positionals
+ trailings = method_type.type.trailing_positionals
+ keyreqs = method_type.type.required_keywords
+ keyopts = method_type.type.optional_keywords
+ keyrest = method_type.type.rest_keywords
+ args = args_types
+ if kwargs_type&.any? && keyreqs.empty? && keyopts.empty? && keyrest.nil?
+ kw_value_type = UnionType[*kwargs_type.values]
+ args += [InstanceType.new(Hash, K: SYMBOL, V: kw_value_type)]
+ end
+ if has_splat
+ score += 1 if args.count(&:itself) <= reqs.size + opts.size + trailings.size
+ elsif reqs.size + trailings.size <= args.size && (rest || args.size <= reqs.size + opts.size + trailings.size)
+ score += 2
+ centers = args[reqs.size...-trailings.size]
+ given = args.first(reqs.size) + centers.take(opts.size) + args.last(trailings.size)
+ expected = (reqs + opts.take(centers.size) + trailings).map(&:type)
+ if rest
+ given << UnionType[*centers.drop(opts.size)]
+ expected << rest.type
+ end
+ if given.any?
+ score += given.zip(expected).count do |t, e|
+ e = from_rbs_type e, receiver_type
+ intersect?(t, e) || (intersect?(STRING, e) && t.methods.include?(:to_str)) || (intersect?(INTEGER, e) && t.methods.include?(:to_int)) || (intersect?(ARRAY, e) && t.methods.include?(:to_ary))
+ end.fdiv(given.size)
+ end
+ end
+ [[method_type, given || [], expected || []], score]
+ end
+ end
+ max_score = methods_with_score.map(&:last).max
+ methods_with_score.select { _2 == max_score }.map(&:first)
+ end
+
+ def self.intersect?(a, b)
+ atypes = a.types.group_by(&:class)
+ btypes = b.types.group_by(&:class)
+ if atypes[SingletonType] && btypes[SingletonType]
+ aa, bb = [atypes, btypes].map {|types| types[SingletonType].map(&:module_or_class) }
+ return true if (aa & bb).any?
+ end
+
+ aa, bb = [atypes, btypes].map {|types| (types[InstanceType] || []).map(&:klass) }
+ (aa.flat_map(&:ancestors) & bb).any?
+ end
+
+ def self.type_from_object(object)
+ case object
+ when Array
+ InstanceType.new Array, { Elem: union_type_from_objects(object) }
+ when Hash
+ InstanceType.new Hash, { K: union_type_from_objects(object.keys), V: union_type_from_objects(object.values) }
+ when Module
+ SingletonType.new object
+ else
+ klass = Methods::OBJECT_SINGLETON_CLASS_METHOD.bind_call(object) rescue Methods::OBJECT_CLASS_METHOD.bind_call(object)
+ InstanceType.new klass
+ end
+ end
+
+ def self.union_type_from_objects(objects)
+ values = objects.size <= OBJECT_TO_TYPE_SAMPLE_SIZE ? objects : objects.sample(OBJECT_TO_TYPE_SAMPLE_SIZE)
+ klasses = values.map { Methods::OBJECT_CLASS_METHOD.bind_call(_1) }
+ UnionType[*klasses.uniq.map { InstanceType.new _1 }]
+ end
+
+ class SingletonType
+ attr_reader :module_or_class
+ def initialize(module_or_class)
+ @module_or_class = module_or_class
+ end
+ def transform() = yield(self)
+ def methods() = @module_or_class.methods
+ def all_methods() = methods | Kernel.methods
+ def constants() = @module_or_class.constants
+ def types() = [self]
+ def nillable?() = false
+ def nonnillable() = self
+ def inspect
+ "#{module_or_class}.itself"
+ end
+ end
+
+ class InstanceType
+ attr_reader :klass, :params
+ def initialize(klass, params = {})
+ @klass = klass
+ @params = params
+ end
+ def transform() = yield(self)
+ def methods() = rbs_methods.select { _2.public? }.keys | @klass.instance_methods
+ def all_methods() = rbs_methods.keys | @klass.instance_methods | @klass.private_instance_methods
+ def constants() = []
+ def types() = [self]
+ def nillable?() = (@klass == NilClass)
+ def nonnillable() = self
+ def rbs_methods
+ name = Types.class_name_of(@klass)
+ return {} unless name && Types.rbs_builder
+
+ type_name = RBS::TypeName(name).absolute!
+ Types.rbs_builder.build_instance(type_name).methods rescue {}
+ end
+ def inspect
+ if params.empty?
+ inspect_without_params
+ else
+ params_string = "[#{params.map { "#{_1}: #{_2.inspect}" }.join(', ')}]"
+ "#{inspect_without_params}#{params_string}"
+ end
+ end
+ def inspect_without_params
+ if klass == NilClass
+ 'nil'
+ elsif klass == TrueClass
+ 'true'
+ elsif klass == FalseClass
+ 'false'
+ else
+ klass.singleton_class? ? klass.superclass.to_s : klass.to_s
+ end
+ end
+ end
+
+ NIL = InstanceType.new NilClass
+ OBJECT = InstanceType.new Object
+ TRUE = InstanceType.new TrueClass
+ FALSE = InstanceType.new FalseClass
+ SYMBOL = InstanceType.new Symbol
+ STRING = InstanceType.new String
+ INTEGER = InstanceType.new Integer
+ RANGE = InstanceType.new Range
+ REGEXP = InstanceType.new Regexp
+ FLOAT = InstanceType.new Float
+ RATIONAL = InstanceType.new Rational
+ COMPLEX = InstanceType.new Complex
+ ARRAY = InstanceType.new Array
+ HASH = InstanceType.new Hash
+ CLASS = InstanceType.new Class
+ MODULE = InstanceType.new Module
+ PROC = InstanceType.new Proc
+
+ class UnionType
+ attr_reader :types
+
+ def initialize(*types)
+ @types = []
+ singletons = []
+ instances = {}
+ collect = -> type do
+ case type
+ in UnionType
+ type.types.each(&collect)
+ in InstanceType
+ params = (instances[type.klass] ||= {})
+ type.params.each do |k, v|
+ (params[k] ||= []) << v
+ end
+ in SingletonType
+ singletons << type
+ end
+ end
+ types.each(&collect)
+ @types = singletons.uniq + instances.map do |klass, params|
+ InstanceType.new(klass, params.transform_values { |v| UnionType[*v] })
+ end
+ end
+
+ def transform(&block)
+ UnionType[*types.map(&block)]
+ end
+
+ def nillable?
+ types.any?(&:nillable?)
+ end
+
+ def nonnillable
+ UnionType[*types.reject { _1.is_a?(InstanceType) && _1.klass == NilClass }]
+ end
+
+ def self.[](*types)
+ type = new(*types)
+ if type.types.empty?
+ OBJECT
+ elsif type.types.size == 1
+ type.types.first
+ else
+ type
+ end
+ end
+
+ def methods() = @types.flat_map(&:methods).uniq
+ def all_methods() = @types.flat_map(&:all_methods).uniq
+ def constants() = @types.flat_map(&:constants).uniq
+ def inspect() = @types.map(&:inspect).join(' | ')
+ end
+
+ BOOLEAN = UnionType[TRUE, FALSE]
+
+ def self.array_of(*types)
+ type = types.size >= 2 ? UnionType[*types] : types.first || OBJECT
+ InstanceType.new Array, Elem: type
+ end
+
+ def self.from_rbs_type(return_type, self_type, extra_vars = {})
+ case return_type
+ when RBS::Types::Bases::Self
+ self_type
+ when RBS::Types::Bases::Bottom, RBS::Types::Bases::Nil
+ NIL
+ when RBS::Types::Bases::Any, RBS::Types::Bases::Void
+ OBJECT
+ when RBS::Types::Bases::Class
+ self_type.transform do |type|
+ case type
+ in SingletonType
+ InstanceType.new(self_type.module_or_class.is_a?(Class) ? Class : Module)
+ in InstanceType
+ SingletonType.new type.klass
+ end
+ end
+ UnionType[*types]
+ when RBS::Types::Bases::Bool
+ BOOLEAN
+ when RBS::Types::Bases::Instance
+ self_type.transform do |type|
+ if type.is_a?(SingletonType) && type.module_or_class.is_a?(Class)
+ InstanceType.new type.module_or_class
+ else
+ OBJECT
+ end
+ end
+ when RBS::Types::Union
+ UnionType[*return_type.types.map { from_rbs_type _1, self_type, extra_vars }]
+ when RBS::Types::Proc
+ PROC
+ when RBS::Types::Tuple
+ elem = UnionType[*return_type.types.map { from_rbs_type _1, self_type, extra_vars }]
+ InstanceType.new Array, Elem: elem
+ when RBS::Types::Record
+ InstanceType.new Hash, K: SYMBOL, V: OBJECT
+ when RBS::Types::Literal
+ InstanceType.new return_type.literal.class
+ when RBS::Types::Variable
+ if extra_vars.key? return_type.name
+ extra_vars[return_type.name]
+ elsif self_type.is_a? InstanceType
+ self_type.params[return_type.name] || OBJECT
+ elsif self_type.is_a? UnionType
+ types = self_type.types.filter_map do |t|
+ t.params[return_type.name] if t.is_a? InstanceType
+ end
+ UnionType[*types]
+ else
+ OBJECT
+ end
+ when RBS::Types::Optional
+ UnionType[from_rbs_type(return_type.type, self_type, extra_vars), NIL]
+ when RBS::Types::Alias
+ case return_type.name.name
+ when :int
+ INTEGER
+ when :boolish
+ BOOLEAN
+ when :string
+ STRING
+ else
+ # TODO: ???
+ OBJECT
+ end
+ when RBS::Types::Interface
+ # unimplemented
+ OBJECT
+ when RBS::Types::ClassInstance
+ klass = return_type.name.to_namespace.path.reduce(Object) { _1.const_get _2 }
+ if return_type.args
+ args = return_type.args.map { from_rbs_type _1, self_type, extra_vars }
+ names = rbs_builder.build_singleton(return_type.name).type_params
+ params = names.map.with_index { [_1, args[_2] || OBJECT] }.to_h
+ end
+ InstanceType.new klass, params || {}
+ end
+ end
+
+ def self.method_return_bottom?(method)
+ method.type.return_type.is_a? RBS::Types::Bases::Bottom
+ end
+
+ def self.match_free_variables(vars, types, values)
+ accumulator = {}
+ types.zip values do |t, v|
+ _match_free_variable(vars, t, v, accumulator) if v
+ end
+ accumulator.transform_values { UnionType[*_1] }
+ end
+
+ def self._match_free_variable(vars, rbs_type, value, accumulator)
+ case [rbs_type, value]
+ in [RBS::Types::Variable,]
+ (accumulator[rbs_type.name] ||= []) << value if vars.include? rbs_type.name
+ in [RBS::Types::ClassInstance, InstanceType]
+ names = rbs_builder.build_singleton(rbs_type.name).type_params
+ names.zip(rbs_type.args).each do |name, arg|
+ v = value.params[name]
+ _match_free_variable vars, arg, v, accumulator if v
+ end
+ in [RBS::Types::Tuple, InstanceType] if value.klass == Array
+ v = value.params[:Elem]
+ rbs_type.types.each do |t|
+ _match_free_variable vars, t, v, accumulator
+ end
+ in [RBS::Types::Record, InstanceType] if value.klass == Hash
+ # TODO
+ in [RBS::Types::Interface,]
+ definition = rbs_builder.build_interface rbs_type.name
+ convert = {}
+ definition.type_params.zip(rbs_type.args).each do |from, arg|
+ convert[from] = arg.name if arg.is_a? RBS::Types::Variable
+ end
+ return if convert.empty?
+ ac = {}
+ definition.methods.each do |method_name, method|
+ return_type = method_return_type value, method_name
+ method.defs.each do |method_def|
+ interface_return_type = method_def.type.type.return_type
+ _match_free_variable convert, interface_return_type, return_type, ac
+ end
+ end
+ convert.each do |from, to|
+ values = ac[from]
+ (accumulator[to] ||= []).concat values if values
+ end
+ else
+ end
+ end
+ end
+ end
+end