diff options
-rw-r--r-- | internal.h | 1 | ||||
-rw-r--r-- | numeric.c | 42 | ||||
-rw-r--r-- | test/ruby/test_enumerator.rb | 15 |
3 files changed, 56 insertions, 2 deletions
diff --git a/internal.h b/internal.h index 6e6a071f92..5b66192ddf 100644 --- a/internal.h +++ b/internal.h @@ -154,6 +154,7 @@ void Init_newline(void); /* numeric.c */ int rb_num_to_uint(VALUE val, unsigned int *ret); +VALUE num_interval_step_size(VALUE from, VALUE to, VALUE step, int excl); int ruby_float_step(VALUE from, VALUE to, VALUE step, int excl); double ruby_float_mod(double x, double y); @@ -100,7 +100,7 @@ static VALUE fix_uminus(VALUE num); static VALUE fix_mul(VALUE x, VALUE y); static VALUE int_pow(long x, unsigned long y); -static ID id_coerce, id_to_i, id_eq; +static ID id_coerce, id_to_i, id_eq, id_div; VALUE rb_cNumeric; VALUE rb_cFloat; @@ -1764,6 +1764,43 @@ ruby_float_step(VALUE from, VALUE to, VALUE step, int excl) return FALSE; } +VALUE +num_interval_step_size(VALUE from, VALUE to, VALUE step, int excl) { + if (FIXNUM_P(from) && FIXNUM_P(to) && FIXNUM_P(step)) { + long delta, diff, result; + + diff = FIX2LONG(step); + delta = FIX2LONG(to) - FIX2LONG(from); + if (excl) { + delta += (diff > 0 ? -1 : +1); + } + result = delta / diff; + return LONG2FIX(result >= 0 ? result + 1 : 0); + } + else if (TYPE(from) == T_FLOAT || TYPE(to) == T_FLOAT || TYPE(step) == T_FLOAT) { + double n = ruby_float_step_size(NUM2DBL(from), NUM2DBL(to), NUM2DBL(step), excl); + + if (isinf(n)) return DBL2NUM(n); + return LONG2FIX(n); + } + else { + VALUE result; + ID cmp = RTEST(rb_funcall(step, '>', 1, INT2FIX(0))) ? '>' : '<'; + if (RTEST(rb_funcall(from, cmp, 1, to))) return INT2FIX(0); + result = rb_funcall(rb_funcall(to, '-', 1, from), id_div, 1, step); + if (!excl || RTEST(rb_funcall(rb_funcall(from, '+', 1, rb_funcall(result, '*', 1, step)), cmp, 1, to))) { + result = rb_funcall(result, '+', 1, INT2FIX(1)); + } + return result; + } +} + +static VALUE +num_step_size(VALUE from, VALUE args) { + VALUE to = RARRAY_PTR(args)[0]; + VALUE step = (RARRAY_LEN(args) > 1) ? RARRAY_PTR(args)[1] : INT2FIX(1); + return num_interval_step_size(from, to, step, FALSE); +} /* * call-seq: * num.step(limit[, step]) {|i| block } -> self @@ -1799,7 +1836,7 @@ num_step(int argc, VALUE *argv, VALUE from) { VALUE to, step; - RETURN_ENUMERATOR(from, argc, argv); + RETURN_SIZED_ENUMERATOR(from, argc, argv, num_step_size); if (argc == 1) { to = argv[0]; step = INT2FIX(1); @@ -3602,6 +3639,7 @@ Init_Numeric(void) id_coerce = rb_intern("coerce"); id_to_i = rb_intern("to_i"); id_eq = rb_intern("=="); + id_div = rb_intern("div"); rb_eZeroDivError = rb_define_class("ZeroDivisionError", rb_eStandardError); rb_eFloatDomainError = rb_define_class("FloatDomainError", rb_eRangeError); diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb index c45e3eb0a7..8203aabc56 100644 --- a/test/ruby/test_enumerator.rb +++ b/test/ruby/test_enumerator.rb @@ -522,5 +522,20 @@ class TestEnumerator < Test::Unit::TestCase assert_equal 0, @sized.each_cons(70).size assert_raise(ArgumentError){ @obj.each_cons(0).size } end + + def test_size_for_step + assert_equal 42, 5.step(46).size + assert_equal 4, 1.step(10, 3).size + assert_equal 3, 1.step(9, 3).size + assert_equal 0, 1.step(-11).size + assert_equal 0, 1.step(-11, 2).size + assert_equal 7, 1.step(-11, -2).size + assert_equal 7, 1.step(-11.1, -2).size + assert_equal 0, 42.step(Float::INFINITY, -2).size + assert_equal 1, 42.step(55, Float::INFINITY).size + assert_equal 1, 42.step(Float::INFINITY, Float::INFINITY).size + assert_equal 14, 0.1.step(4.2, 0.3).size + assert_equal Float::INFINITY, 42.step(Float::INFINITY, 2).size + end end |