diff options
author | nobu <nobu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e> | 2010-08-26 22:57:39 +0000 |
---|---|---|
committer | nobu <nobu@b2dd03c8-39d4-4d8f-98ff-823fe69b080e> | 2010-08-26 22:57:39 +0000 |
commit | 5b7ccc0629baa7cd2c7ab92802ee1bf62e3ec0f4 (patch) | |
tree | 145ed61a619f6c58d4df760ea8c727527456f841 /array.c | |
parent | 6e08cd5ec2e4133bbb02ccd196206c60f9cb6795 (diff) | |
download | ruby-5b7ccc0629baa7cd2c7ab92802ee1bf62e3ec0f4.tar.gz |
* array.c (rb_ary_shuffle_bang): bail out from modification during
shuffle.
* array.c (rb_ary_sample): ditto.
git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@29108 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
Diffstat (limited to 'array.c')
-rw-r--r-- | array.c | 64 |
1 files changed, 46 insertions, 18 deletions
@@ -20,6 +20,8 @@ #endif #include <assert.h> +#define numberof(array) (int)(sizeof(array) / sizeof((array)[0])) + VALUE rb_cArray; static ID id_cmp; @@ -3748,8 +3750,8 @@ static VALUE sym_random; static VALUE rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary) { - VALUE *ptr, opts, randgen = rb_cRandom; - long i = RARRAY_LEN(ary); + VALUE *ptr, opts, *snap_ptr, randgen = rb_cRandom; + long i, snap_len; if (OPTHASH_GIVEN_P(opts)) { randgen = rb_hash_lookup2(opts, sym_random, randgen); @@ -3758,10 +3760,17 @@ rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary) rb_raise(rb_eArgError, "wrong number of arguments (%d for 0)", argc); } rb_ary_modify(ary); + i = RARRAY_LEN(ary); ptr = RARRAY_PTR(ary); + snap_len = i; + snap_ptr = ptr; while (i) { long j = RAND_UPTO(i); - VALUE tmp = ptr[--i]; + VALUE tmp; + if (snap_len != RARRAY_LEN(ary) || snap_ptr != RARRAY_PTR(ary)) { + rb_raise(rb_eRuntimeError, "modified during shuffle"); + } + tmp = ptr[--i]; ptr[i] = ptr[j]; ptr[j] = tmp; } @@ -3814,37 +3823,54 @@ static VALUE rb_ary_sample(int argc, VALUE *argv, VALUE ary) { VALUE nv, result, *ptr; - VALUE opts, randgen = rb_cRandom; + VALUE opts, snap, randgen = rb_cRandom; long n, len, i, j, k, idx[10]; + double rnds[numberof(idx)]; - len = RARRAY_LEN(ary); if (OPTHASH_GIVEN_P(opts)) { randgen = rb_hash_lookup2(opts, sym_random, randgen); } + ptr = RARRAY_PTR(ary); + len = RARRAY_LEN(ary); if (argc == 0) { if (len == 0) return Qnil; - i = len == 1 ? 0 : RAND_UPTO(len); + if (len == 1) { + i = 0; + } + else { + double x = rb_random_real(randgen); + if ((len = RARRAY_LEN(ary)) == 0) return Qnil; + i = (long)(x * len); + } return RARRAY_PTR(ary)[i]; } rb_scan_args(argc, argv, "1", &nv); n = NUM2LONG(nv); if (n < 0) rb_raise(rb_eArgError, "negative sample number"); - ptr = RARRAY_PTR(ary); + if (n > len) n = len; + if (n <= numberof(idx)) { + for (i = 0; i < n; ++i) { + rnds[i] = rb_random_real(randgen); + } + } len = RARRAY_LEN(ary); + ptr = RARRAY_PTR(ary); if (n > len) n = len; switch (n) { - case 0: return rb_ary_new2(0); + case 0: + return rb_ary_new2(0); case 1: - return rb_ary_new4(1, &ptr[RAND_UPTO(len)]); + i = (long)(rnds[0] * len); + return rb_ary_new4(1, &ptr[i]); case 2: - i = RAND_UPTO(len); - j = RAND_UPTO(len-1); + i = (long)(rnds[0] * len); + j = (long)(rnds[1] * (len-1)); if (j >= i) j++; return rb_ary_new3(2, ptr[i], ptr[j]); case 3: - i = RAND_UPTO(len); - j = RAND_UPTO(len-1); - k = RAND_UPTO(len-2); + i = (long)(rnds[0] * len); + j = (long)(rnds[1] * (len-1)); + k = (long)(rnds[2] * (len-2)); { long l = j, g = i; if (j >= i) l = i, g = ++j; @@ -3852,12 +3878,12 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) } return rb_ary_new3(3, ptr[i], ptr[j], ptr[k]); } - if ((size_t)n < sizeof(idx)/sizeof(idx[0])) { + if (n <= numberof(idx)) { VALUE *ptr_result; - long sorted[sizeof(idx)/sizeof(idx[0])]; - sorted[0] = idx[0] = RAND_UPTO(len); + long sorted[numberof(idx)]; + sorted[0] = idx[0] = (long)(rnds[0] * len); for (i=1; i<n; i++) { - k = RAND_UPTO(--len); + k = (long)(rnds[i] * --len); for (j = 0; j < i; ++j) { if (k < sorted[j]) break; ++k; @@ -3874,6 +3900,7 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) else { VALUE *ptr_result; result = rb_ary_new4(len, ptr); + RBASIC(result)->klass = 0; ptr_result = RARRAY_PTR(result); RB_GC_GUARD(ary); for (i=0; i<n; i++) { @@ -3882,6 +3909,7 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) ptr_result[j] = ptr_result[i]; ptr_result[i] = nv; } + RBASIC(result)->klass = rb_cArray; } ARY_SET_LEN(result, n); |