diff options
-rw-r--r-- | enumerator.c | 71 | ||||
-rw-r--r-- | include/ruby/intern.h | 7 | ||||
-rw-r--r-- | test/ruby/test_enumerator.rb | 10 |
3 files changed, 74 insertions, 14 deletions
diff --git a/enumerator.c b/enumerator.c index d5d276f6aa..7695ea1a56 100644 --- a/enumerator.c +++ b/enumerator.c @@ -119,6 +119,8 @@ struct enumerator { VALUE lookahead; VALUE feedvalue; VALUE stop_exc; + VALUE size; + VALUE (*size_fn)(ANYARGS); }; static VALUE rb_cGenerator, rb_cYielder; @@ -148,6 +150,7 @@ enumerator_mark(void *p) rb_gc_mark(ptr->lookahead); rb_gc_mark(ptr->feedvalue); rb_gc_mark(ptr->stop_exc); + rb_gc_mark(ptr->size); } #define enumerator_free RUBY_TYPED_DEFAULT_FREE @@ -216,7 +219,7 @@ obj_to_enum(int argc, VALUE *argv, VALUE obj) --argc; meth = *argv++; } - return rb_enumeratorize(obj, meth, argc, argv); + return rb_enumeratorize(obj, meth, argc, argv, 0); } static VALUE @@ -232,7 +235,7 @@ enumerator_allocate(VALUE klass) } static VALUE -enumerator_init(VALUE enum_obj, VALUE obj, VALUE meth, int argc, VALUE *argv) +enumerator_init(VALUE enum_obj, VALUE obj, VALUE meth, int argc, VALUE *argv, VALUE (*size_fn)(ANYARGS), VALUE size) { struct enumerator *ptr; @@ -250,13 +253,15 @@ enumerator_init(VALUE enum_obj, VALUE obj, VALUE meth, int argc, VALUE *argv) ptr->lookahead = Qundef; ptr->feedvalue = Qundef; ptr->stop_exc = Qfalse; + ptr->size = size; + ptr->size_fn = size_fn; return enum_obj; } /* * call-seq: - * Enumerator.new { |yielder| ... } + * Enumerator.new(size = nil) { |yielder| ... } * Enumerator.new(obj, method = :each, *args) * * Creates a new Enumerator object, which can be used as an @@ -276,6 +281,10 @@ enumerator_init(VALUE enum_obj, VALUE obj, VALUE meth, int argc, VALUE *argv) * * p fib.take(10) # => [1, 1, 2, 3, 5, 8, 13, 21, 34, 55] * + * The optional parameter can be used to specify how to calculate the size + * in a lazy fashion (see Enumerable#size). It can either be a value or + * a callable object. + * * The block form can be used to create a lazy enumeration that only processes * elements as-needed. The generic pattern for this is: * @@ -349,14 +358,23 @@ static VALUE enumerator_initialize(int argc, VALUE *argv, VALUE obj) { VALUE recv, meth = sym_each; + VALUE size = Qnil; - if (argc == 0) { - if (!rb_block_given_p()) - rb_check_arity(argc, 1, UNLIMITED_ARGUMENTS); - + if (rb_block_given_p()) { + rb_check_arity(argc, 0, 1); recv = generator_init(generator_allocate(rb_cGenerator), rb_block_proc()); + if (argc) { + if (NIL_P(argv[0]) || rb_obj_is_proc(argv[0]) || + (TYPE(argv[0]) == T_FLOAT && RFLOAT_VALUE(argv[0]) == INFINITY)) { + size = argv[0]; + } else { + size = rb_to_int(argv[0]); + } + argc = 0; + } } else { + rb_check_arity(argc, 1, UNLIMITED_ARGUMENTS); rb_warn("Enumerator.new without a block is deprecated; use Object#to_enum"); recv = *argv++; if (--argc) { @@ -365,7 +383,7 @@ enumerator_initialize(int argc, VALUE *argv, VALUE obj) } } - return enumerator_init(obj, recv, meth, argc, argv); + return enumerator_init(obj, recv, meth, argc, argv, 0, size); } /* :nodoc: */ @@ -393,14 +411,16 @@ enumerator_init_copy(VALUE obj, VALUE orig) ptr1->fib = 0; ptr1->lookahead = Qundef; ptr1->feedvalue = Qundef; + ptr1->size = ptr0->size; + ptr1->size_fn = ptr0->size_fn; return obj; } VALUE -rb_enumeratorize(VALUE obj, VALUE meth, int argc, VALUE *argv) +rb_enumeratorize(VALUE obj, VALUE meth, int argc, VALUE *argv, VALUE (*size_fn)(ANYARGS)) { - return enumerator_init(enumerator_allocate(rb_cEnumerator), obj, meth, argc, argv); + return enumerator_init(enumerator_allocate(rb_cEnumerator), obj, meth, argc, argv, size_fn, Qnil); } static VALUE @@ -943,6 +963,34 @@ enumerator_inspect(VALUE obj) } /* + * call-seq: + * e.size -> int, Float::INFINITY or nil + * + * Returns the size of the enumerator, or +nil+ if it can't be calculated lazily. + * + * (1..100).to_a.permutation(4).size # => 94109400 + * loop.size # => Float::INFINITY + * (1..100).drop_while.size # => nil + */ + +static VALUE +enumerator_size(VALUE obj) +{ + struct enumerator *e = enumerator_ptr(obj); + + if (e->size_fn) { + return (*e->size_fn)(e->obj, e->args); + } + if (rb_obj_is_proc(e->size)) { + if(e->args) + return rb_proc_call(e->size, e->args); + else + return rb_proc_call_with_block(e->size, 0, 0, Qnil); + } + return e->size; +} + +/* * Yielder */ static void @@ -1253,7 +1301,7 @@ lazy_initialize(int argc, VALUE *argv, VALUE self) rb_block_call(generator, id_initialize, 0, 0, (rb_block_given_p() ? lazy_init_block_i : lazy_init_block), obj); - enumerator_init(self, generator, meth, argc - offset, argv + offset); + enumerator_init(self, generator, meth, argc - offset, argv + offset, 0, Qnil); rb_ivar_set(self, id_receiver, obj); return self; @@ -1749,6 +1797,7 @@ InitVM_Enumerator(void) rb_define_method(rb_cEnumerator, "feed", enumerator_feed, 1); rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0); rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0); + rb_define_method(rb_cEnumerator, "size", enumerator_size, 0); /* Lazy */ rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator); diff --git a/include/ruby/intern.h b/include/ruby/intern.h index 5e68dbdcd3..1799e79317 100644 --- a/include/ruby/intern.h +++ b/include/ruby/intern.h @@ -201,12 +201,13 @@ VALUE rb_fiber_alive_p(VALUE); /* enum.c */ VALUE rb_enum_values_pack(int, VALUE*); /* enumerator.c */ -VALUE rb_enumeratorize(VALUE, VALUE, int, VALUE *); -#define RETURN_ENUMERATOR(obj, argc, argv) do { \ +VALUE rb_enumeratorize(VALUE, VALUE, int, VALUE *, VALUE (*)(ANYARGS)); +#define RETURN_SIZED_ENUMERATOR(obj, argc, argv, size_fn) do { \ if (!rb_block_given_p()) \ return rb_enumeratorize((obj), ID2SYM(rb_frame_this_func()),\ - (argc), (argv)); \ + (argc), (argv), (size_fn)); \ } while (0) +#define RETURN_ENUMERATOR(obj, argc, argv) RETURN_SIZED_ENUMERATOR(obj, argc, argv, 0) /* error.c */ VALUE rb_exc_new(VALUE, const char*, long); VALUE rb_exc_new2(VALUE, const char*); diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb index 6b8a09f80f..0f19a4f8e7 100644 --- a/test/ruby/test_enumerator.rb +++ b/test/ruby/test_enumerator.rb @@ -401,5 +401,15 @@ class TestEnumerator < Test::Unit::TestCase assert_raise(LocalJumpError) { Enumerator::Yielder.new } end + + def test_size + assert_equal nil, Enumerator.new{}.size + assert_equal 42, Enumerator.new(->{42}){}.size + assert_equal 42, Enumerator.new(42){}.size + assert_equal 1 << 70, Enumerator.new(1 << 70){}.size + assert_equal Float::INFINITY, Enumerator.new(Float::INFINITY){}.size + assert_equal nil, Enumerator.new(nil){}.size + assert_raise(TypeError) { Enumerator.new("42"){} } + end end |