aboutsummaryrefslogtreecommitdiffstats
path: root/range.c
diff options
context:
space:
mode:
authorKouhei Yanagita <yanagi@shakenbu.org>2023-09-19 20:18:00 +0900
committerNobuyoshi Nakada <nobu@ruby-lang.org>2023-09-26 17:31:10 +0900
commit4199e49cad0adddb48d58fa2b0a50563bfd40dac (patch)
tree00c7a78e924dfa85fc91b17bfe95043ade648a6e /range.c
parent91042ec0ae2a8285ad68c101ff384bd7e3f4e260 (diff)
downloadruby-4199e49cad0adddb48d58fa2b0a50563bfd40dac.tar.gz
Optimize Range#bsearch by reducing the number of Integer#+ calls
Diffstat (limited to 'range.c')
-rw-r--r--range.c51
1 files changed, 28 insertions, 23 deletions
diff --git a/range.c b/range.c
index a8bdb98b53..199d1056db 100644
--- a/range.c
+++ b/range.c
@@ -649,27 +649,30 @@ bsearch_integer_range(VALUE beg, VALUE end, int excl)
VALUE low = rb_to_int(beg);
VALUE high = rb_to_int(end);
- VALUE mid, org_high;
+ VALUE mid;
ID id_div;
CONST_ID(id_div, "div");
- if (excl) high = rb_funcall(high, '-', 1, INT2FIX(1));
- org_high = high;
+ if (!excl) high = rb_funcall(high, '+', 1, INT2FIX(1));
+ low = rb_funcall(low, '-', 1, INT2FIX(1));
- while (rb_cmpint(rb_funcall(low, id_cmp, 1, high), low, high) < 0) {
- mid = rb_funcall(rb_funcall(high, '+', 1, low), id_div, 1, INT2FIX(2));
+ /*
+ * This loop must continue while low + 1 < high.
+ * Instead of checking low + 1 < high, check low < mid, where mid = (low + high) / 2.
+ * This is to avoid the cost of calculating low + 1 on each iteration.
+ * Note that this condition replacement is valid because Integer#div always rounds
+ * towards negative infinity.
+ */
+ while (mid = rb_funcall(rb_funcall(high, '+', 1, low), id_div, 1, INT2FIX(2)),
+ rb_cmpint(rb_funcall(low, id_cmp, 1, mid), low, mid) < 0) {
BSEARCH_CHECK(mid);
if (smaller) {
high = mid;
}
else {
- low = rb_funcall(mid, '+', 1, INT2FIX(1));
+ low = mid;
}
}
- if (rb_equal(low, org_high)) {
- BSEARCH_CHECK(low);
- if (!smaller) return Qnil;
- }
return satisfied;
}
@@ -696,8 +699,14 @@ range_bsearch(VALUE range)
* by the mantissa. This is true with or without implicit bit.
*
* Finding the average of two ints needs to be careful about
- * potential overflow (since float to long can use 64 bits)
- * as well as the fact that -1/2 can be 0 or -1 in C89.
+ * potential overflow (since float to long can use 64 bits).
+ *
+ * The half-open interval (low, high] indicates where the target is located.
+ * The loop continues until low and high are adjacent.
+ *
+ * -1/2 can be either 0 or -1 in C89. However, when low and high are not adjacent,
+ * the rounding direction of mid = (low + high) / 2 does not affect the result of
+ * the binary search.
*
* Note that -0.0 is mapped to the same int as 0.0 as we don't want
* (-1...0.0).bsearch to yield -0.0.
@@ -706,23 +715,19 @@ range_bsearch(VALUE range)
#define BSEARCH(conv, excl) \
do { \
RETURN_ENUMERATOR(range, 0, 0); \
- if (excl) high--; \
- org_high = high; \
- while (low < high) { \
+ if (!(excl)) high++; \
+ low--; \
+ while (low + 1 < high) { \
mid = ((high < 0) == (low < 0)) ? low + ((high - low) / 2) \
- : (low < -high) ? -((-1 - low - high)/2 + 1) : (low + high) / 2; \
+ : (low + high) / 2; \
BSEARCH_CHECK(conv(mid)); \
if (smaller) { \
high = mid; \
} \
else { \
- low = mid + 1; \
+ low = mid; \
} \
} \
- if (low == org_high) { \
- BSEARCH_CHECK(conv(low)); \
- if (!smaller) return Qnil; \
- } \
return satisfied; \
} while (0)
@@ -730,7 +735,7 @@ range_bsearch(VALUE range)
do { \
long low = FIX2LONG(beg); \
long high = FIX2LONG(end); \
- long mid, org_high; \
+ long mid; \
BSEARCH(INT2FIX, (excl)); \
} while (0)
@@ -744,7 +749,7 @@ range_bsearch(VALUE range)
else if (RB_FLOAT_TYPE_P(beg) || RB_FLOAT_TYPE_P(end)) {
int64_t low = double_as_int64(NIL_P(beg) ? -HUGE_VAL : RFLOAT_VALUE(rb_Float(beg)));
int64_t high = double_as_int64(NIL_P(end) ? HUGE_VAL : RFLOAT_VALUE(rb_Float(end)));
- int64_t mid, org_high;
+ int64_t mid;
BSEARCH(int64_as_double_to_num, EXCL(range));
}
#endif