From 80b537f14c4af699b26e52931ae7e64a547e68c5 Mon Sep 17 00:00:00 2001 From: Kazuki Yamaguchi Date: Wed, 17 Feb 2016 17:42:09 +0900 Subject: {Enumerable,Array,Range}#first, {Array,Range}#last with block * array.c (rb_ary_first, ary_last): extend Array#{first,last} to accept a block. If a block is passed, these methods collects only elements for which the block returns a truthy value. * enum.c: extend Enumerable#first to accept a block. * range.c: extend Range#{first,last} to accept a block. * gc.c: avoid using rb_ary_last(), because it may call a block. * test/ruby/test_array.rb: add test * test/ruby/test_enum.rb: ditto * test/ruby/test_range.rb: ditto --- array.c | 109 +++++++++++++++++++++++++++++++++++++----------- enum.c | 47 ++++++++++++++++----- gc.c | 4 +- range.c | 49 +++++++++++++++------- test/ruby/test_array.rb | 10 +++++ test/ruby/test_enum.rb | 2 + test/ruby/test_range.rb | 8 ++++ 7 files changed, 178 insertions(+), 51 deletions(-) diff --git a/array.c b/array.c index 7a01dbc984..9457400013 100644 --- a/array.c +++ b/array.c @@ -1310,56 +1310,117 @@ rb_ary_at(VALUE ary, VALUE pos) /* * call-seq: - * ary.first -> obj or nil - * ary.first(n) -> new_ary + * ary.first -> obj or nil + * ary.first { |obj| block } -> obj or nil + * ary.first(n) -> new_ary + * ary.first(n) { |obj| block } -> new_ary * * Returns the first element, or the first +n+ elements, of the array. - * If the array is empty, the first form returns +nil+, and the - * second form returns an empty array. See also Array#last for - * the opposite effect. + * If the array is empty, the first form returns +nil+, and the second form + * returns an empty array. + * If a block is given, only elements for which the given block returns a true + * value are counted. + * + * See also Array#last for the opposite effect. * - * a = [ "q", "r", "s", "t" ] - * a.first #=> "q" - * a.first(2) #=> ["q", "r"] + * a = [ "q", "r", "s", "t", "aa" ] + * a.first #=> "q" + * a.first(2) #=> ["q", "r"] + * a.first { |i| i.size > 1 } #=> "aa" */ static VALUE rb_ary_first(int argc, VALUE *argv, VALUE ary) { - if (argc == 0) { - if (RARRAY_LEN(ary) == 0) return Qnil; - return RARRAY_AREF(ary, 0); + if (!rb_block_given_p()) { + if (argc == 0) { + if (RARRAY_LEN(ary) == 0) return Qnil; + return RARRAY_AREF(ary, 0); + } + else { + return ary_take_first_or_last(argc, argv, ary, ARY_TAKE_FIRST); + } } else { - return ary_take_first_or_last(argc, argv, ary, ARY_TAKE_FIRST); + long i; + if (argc == 0) { + for (i = 0; i < RARRAY_LEN(ary); i++) { + if (RTEST(rb_yield(RARRAY_AREF(ary, i)))) + return RARRAY_AREF(ary, i); + } + return Qnil; + } + else { + long take = NUM2LONG(argv[0]); + VALUE result = rb_ary_new();; + if (take < 0) rb_raise(rb_eArgError, "attempt to take negative size"); + if (take == 0) return rb_ary_new2(0); + for (i = 0; i < RARRAY_LEN(ary); i++) { + if (RTEST(rb_yield(RARRAY_AREF(ary, i)))) { + rb_ary_push(result, RARRAY_AREF(ary, i)); + if (!--take) break; + } + } + return result; + } } } /* * call-seq: - * ary.last -> obj or nil - * ary.last(n) -> new_ary + * ary.last -> obj or nil + * ary.last { |obj| block } -> obj or nil + * ary.last(n) -> new_ary + * ary.last(n) { |obj| block } -> new_ary * - * Returns the last element(s) of +self+. If the array is empty, - * the first form returns +nil+. + * Returns the last element(s) of +self+. If the array is empty, the first + * form returns +nil+. + * If a block is given, only elements for which the given block returns a true + * value are counted. * * See also Array#first for the opposite effect. * - * a = [ "w", "x", "y", "z" ] - * a.last #=> "z" - * a.last(2) #=> ["y", "z"] + * a = [ "w", "x", "y", "z", "aa" ] + * a.last #=> "aa" + * a.last(2) #=> ["z", "aa"] + * a.last { |i| i.size == 1 } #=> "x" */ VALUE rb_ary_last(int argc, const VALUE *argv, VALUE ary) { - if (argc == 0) { - long len = RARRAY_LEN(ary); - if (len == 0) return Qnil; - return RARRAY_AREF(ary, len-1); + if (!rb_block_given_p()) { + if (argc == 0) { + long len = RARRAY_LEN(ary); + if (len == 0) return Qnil; + return RARRAY_AREF(ary, len-1); + } + else { + return ary_take_first_or_last(argc, argv, ary, ARY_TAKE_LAST); + } } else { - return ary_take_first_or_last(argc, argv, ary, ARY_TAKE_LAST); + long i; + if (argc == 0) { + for (i = RARRAY_LEN(ary); --i >= 0; ) { + if (RTEST(rb_yield(RARRAY_AREF(ary, i)))) + return RARRAY_AREF(ary, i); + } + return Qnil; + } + else { + long take = NUM2LONG(argv[0]); + VALUE result = rb_ary_new();; + if (take < 0) rb_raise(rb_eArgError, "attempt to take negative size"); + if (take == 0) return rb_ary_new2(0); + for (i = RARRAY_LEN(ary); --i >= 0; ) { + if (RTEST(rb_yield(RARRAY_AREF(ary, i)))) { + rb_ary_push(result, RARRAY_AREF(ary, i)); + if (!--take) break; + } + } + return rb_ary_reverse(result); + } } } diff --git a/enum.c b/enum.c index 23e4f5a5e4..e1f0e60d4f 100644 --- a/enum.c +++ b/enum.c @@ -909,22 +909,40 @@ first_i(RB_BLOCK_CALL_FUNC_ARGLIST(i, params)) UNREACHABLE; } +static VALUE +take_find_i(RB_BLOCK_CALL_FUNC_ARGLIST(i, params)) +{ + struct MEMO *memo = MEMO_CAST(params); + ENUM_WANT_SVALUE(); + if (RTEST(rb_yield(i))) { + rb_ary_push(memo->v1, i); + if (!--memo->u3.cnt) rb_iter_break(); + } + return Qnil; +} + static VALUE enum_take(VALUE obj, VALUE n); /* * call-seq: - * enum.first -> obj or nil - * enum.first(n) -> an_array + * enum.first -> obj or nil + * enum.first { |obj| block } -> obj or nil + * enum.first(n) -> an_array + * enum.first(n) { |obj| block } -> an_array * * Returns the first element, or the first +n+ elements, of the enumerable. * If the enumerable is empty, the first form returns nil, and the * second form returns an empty array. + * If a block is given, only elements for which the given block returns a true + * value are counted. * - * %w[foo bar baz].first #=> "foo" - * %w[foo bar baz].first(2) #=> ["foo", "bar"] - * %w[foo bar baz].first(10) #=> ["foo", "bar", "baz"] - * [].first #=> nil - * [].first(10) #=> [] + * %w[foo bar baz].first #=> "foo" + * %w[foo bar baz].first(2) #=> ["foo", "bar"] + * %w[foo bar baz].first(10) #=> ["foo", "bar", "baz"] + * [].first #=> nil + * [].first(10) #=> [] + * [1,2,3,4].first { |i| i.even? } #=> 2 + * [1,2,2,2].first(2) { |i| i > 1 } #=> [2, 2] * */ @@ -932,18 +950,27 @@ static VALUE enum_first(int argc, VALUE *argv, VALUE obj) { struct MEMO *memo; + long len; rb_check_arity(argc, 0, 1); + if (argc > 0) { - return enum_take(obj, argv[0]); + if (!rb_block_given_p()) return enum_take(obj, argv[0]); + + len = NUM2LONG(argv[0]); + if (len < 0) rb_raise(rb_eArgError, "attempt to take negative size"); + if (len == 0) return rb_ary_new2(0); + + memo = MEMO_NEW(rb_ary_new(), 0, len); + rb_block_call(obj, id_each, 0, 0, take_find_i, (VALUE)memo); + return memo->v1; } else { memo = MEMO_NEW(Qnil, 0, 0); - rb_block_call(obj, id_each, 0, 0, first_i, (VALUE)memo); + rb_block_call(obj, id_each, 0, 0, rb_block_given_p() ? find_i : first_i, (VALUE)memo); return memo->v1; } } - /* * call-seq: * enum.sort -> array diff --git a/gc.c b/gc.c index 2baf7a8665..f2853666d7 100644 --- a/gc.c +++ b/gc.c @@ -6094,7 +6094,9 @@ void rb_gc_register_mark_object(VALUE obj) { VALUE ary_ary = GET_THREAD()->vm->mark_object_ary; - VALUE ary = rb_ary_last(0, 0, ary_ary); + VALUE ary = Qnil; + if (RARRAY_LEN(ary_ary)) + ary = RARRAY_AREF(ary_ary, RARRAY_LEN(ary_ary) - 1); if (ary == Qnil || RARRAY_LEN(ary) >= MARK_OBJECT_ARY_BUCKET_SIZE) { ary = rb_ary_tmp_new(MARK_OBJECT_ARY_BUCKET_SIZE); diff --git a/range.c b/range.c index ab3f1af9e2..ea71323261 100644 --- a/range.c +++ b/range.c @@ -851,14 +851,20 @@ first_i(RB_BLOCK_CALL_FUNC_ARGLIST(i, cbarg)) /* * call-seq: - * rng.first -> obj - * rng.first(n) -> an_array + * rng.first -> obj + * rng.first { |obj| block } -> obj + * rng.first(n) -> an_array + * rng.first(n) { |obj| block } -> an_array * * Returns the first object in the range, or an array of the first +n+ * elements. + * If a block is given, only elements for which the given block returns a true + * value are counted. * - * (10..20).first #=> 10 - * (10..20).first(3) #=> [10, 11, 12] + * (10..20).first #=> 10 + * (10..20).first { |i| i.odd? } #=> 11 + * (10..20).first(3) #=> [10, 11, 12] + * (10..20).first(3) { |i| i.odd? } #=> [11, 13, 15] */ static VALUE @@ -866,6 +872,7 @@ range_first(int argc, VALUE *argv, VALUE range) { VALUE n, ary[2]; + if (rb_block_given_p()) return rb_call_super(argc, argv); if (argc == 0) return RANGE_BEG(range); rb_scan_args(argc, argv, "1", &n); @@ -879,26 +886,36 @@ range_first(int argc, VALUE *argv, VALUE range) /* * call-seq: - * rng.last -> obj - * rng.last(n) -> an_array + * rng.last -> obj + * rng.last { |obj| block } -> obj + * rng.last(n) -> an_array + * rng.last(n) { |obj| block } -> an_array * * Returns the last object in the range, * or an array of the last +n+ elements. - * - * Note that with no arguments +last+ will return the object that defines - * the end of the range even if #exclude_end? is +true+. - * - * (10..20).last #=> 20 - * (10...20).last #=> 20 - * (10..20).last(3) #=> [18, 19, 20] - * (10...20).last(3) #=> [17, 18, 19] + * If a block is given, only elements for which the given block returns a true + * value are counted. + * + * Note that with no arguments nor a block +last+ will return the object that + * defines the end of the range even if #exclude_end? is +true+. + * + * (10..20).last #=> 20 + * (10...20).last #=> 20 + * (10...20).last { true } #=> 19 + * (10..20).last(3) #=> [18, 19, 20] + * (10...20).last(3) #=> [17, 18, 19] + * (10...20).last(3) { |i| i.odd? } #=> [15, 17, 19] */ static VALUE range_last(int argc, VALUE *argv, VALUE range) { - if (argc == 0) return RANGE_END(range); - return rb_ary_last(argc, argv, rb_Array(range)); + if (argc > 0 || rb_block_given_p()) { + return rb_ary_last(argc, argv, rb_Array(range)); + } + else { + return RANGE_END(range); + } } diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb index fa158185f8..d0f12b280a 100644 --- a/test/ruby/test_array.rb +++ b/test/ruby/test_array.rb @@ -756,6 +756,8 @@ class TestArray < Test::Unit::TestCase def test_first assert_equal(3, @cls[3, 4, 5].first) assert_equal(nil, @cls[].first) + assert_equal(2, @cls[1, 2, 3, 4].first { |i| i.even? }) + assert_equal(nil, @cls[1, 2, 3, 4].first { |i| i > 100 }) end def test_flatten @@ -1059,6 +1061,8 @@ class TestArray < Test::Unit::TestCase assert_equal(nil, @cls[].last) assert_equal(1, @cls[1].last) assert_equal(99, @cls[*(3..99).to_a].last) + assert_equal(3, @cls[1, 2, 3, 4].last { |i| i.odd? }) + assert_equal(nil, @cls[1, 2, 3, 4].last { |i| i > 100 }) end def test_length @@ -1999,11 +2003,17 @@ class TestArray < Test::Unit::TestCase def test_first2 assert_equal([0], [0].first(2)) assert_raise(ArgumentError) { [0].first(-1) } + assert_equal([2, 4], @cls[1, 2, 4, 6].first(2) { |i| i.even? }) + assert_equal([2, 4, 6], @cls[2, 4, 5, 6].first(10) { |i| i.even? }) + assert_raise(ArgumentError) { @cls[1, 2].first(-1) { |i| i.even? } } end def test_last2 assert_equal([0], [0].last(2)) assert_raise(ArgumentError) { [0].last(-1) } + assert_equal([4, 6], @cls[2, 4, 5, 6].last(2) { |i| i.even? }) + assert_equal([2, 4, 6], @cls[2, 4, 5, 6].last(10) { |i| i.even? }) + assert_raise(ArgumentError) { @cls[1, 2].last(-1) { |i| i.even? } } end def test_shift2 diff --git a/test/ruby/test_enum.rb b/test/ruby/test_enum.rb index ba973e2d48..c230654751 100644 --- a/test/ruby/test_enum.rb +++ b/test/ruby/test_enum.rb @@ -267,6 +267,8 @@ class TestEnumerable < Test::Unit::TestCase assert_equal([1, 2, 3], @obj.first(3)) assert_nil(@empty.first) assert_equal([], @empty.first(10)) + assert_equal(2, @obj.first { |i| i.even? }) + assert_equal([3], @obj.first(2) { |i| i > 2 }) bug5801 = '[ruby-dev:45041]' assert_in_out_err([], <<-'end;', [], /unexpected break/, bug5801) diff --git a/test/ruby/test_range.rb b/test/ruby/test_range.rb index 2fc2a2b1d3..733d45b6f4 100644 --- a/test/ruby/test_range.rb +++ b/test/ruby/test_range.rb @@ -271,6 +271,10 @@ class TestRange < Test::Unit::TestCase assert_equal("a", ("a".."c").first) assert_equal("c", ("a".."c").last) assert_equal(0, (2..0).last) + assert_equal(1, (0..11).first { |i| i.odd? }) + assert_equal(11, (0..11).last { |i| i.odd? }) + assert_equal([1, 3], (0..11).first(2) { |i| i.odd? }) + assert_equal([9, 11], (0..11).last(2) { |i| i.odd? }) assert_equal([0, 1, 2], (0...10).first(3)) assert_equal([7, 8, 9], (0...10).last(3)) @@ -279,6 +283,10 @@ class TestRange < Test::Unit::TestCase assert_equal("a", ("a"..."c").first) assert_equal("c", ("a"..."c").last) assert_equal(0, (2...0).last) + assert_equal(1, (0...11).first { |i| i.odd? }) + assert_equal(9, (0...11).last { |i| i.odd? }) + assert_equal([1, 3], (0...11).first(2) { |i| i.odd? }) + assert_equal([7, 9], (0...11).last(2) { |i| i.odd? }) end def test_to_s -- cgit v1.2.3