summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKazuki Yamaguchi <k@rhe.jp>2016-08-22 17:42:31 +0900
committerKazuki Yamaguchi <k@rhe.jp>2016-08-22 17:46:47 +0900
commitb099663eb81f4ef6ff8963271a04442cef2667dd (patch)
tree7cf3da2faefbc205da5f5adc4e61a60bd40a3a8e
parent5c1045ea40aa0ca76d8288fa0e91fdaa412bbb83 (diff)
downloadruby-openssl-history-b099663eb81f4ef6ff8963271a04442cef2667dd.tar.gz
pkey: allow non-BN object as the multiplier in PKey::EC::Point#mul
-rw-r--r--ext/openssl/ossl_pkey_ec.c13
-rw-r--r--test/test_pkey_ec.rb15
2 files changed, 14 insertions, 14 deletions
diff --git a/ext/openssl/ossl_pkey_ec.c b/ext/openssl/ossl_pkey_ec.c
index 5ddc4f2..e55e897 100644
--- a/ext/openssl/ossl_pkey_ec.c
+++ b/ext/openssl/ossl_pkey_ec.c
@@ -1698,11 +1698,11 @@ static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
Require_EC_POINT(result, point_result);
rb_scan_args(argc, argv, "12", &arg1, &arg2, &arg3);
- if (rb_obj_is_kind_of(arg1, cBN)) {
+ if (!RB_TYPE_P(arg1, T_ARRAY)) {
BIGNUM *bn = GetBNPtr(arg1);
- if (argc >= 2)
- bn_g = GetBNPtr(arg2);
+ if (!NIL_P(arg2))
+ bn_g = GetBNPtr(arg2);
if (EC_POINT_mul(group, point_result, bn_g, point_self, bn, ossl_bn_ctx) != 1)
ossl_raise(eEC_POINT, NULL);
} else {
@@ -1715,9 +1715,8 @@ static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
const EC_POINT **points;
const BIGNUM **bignums;
- if (!rb_obj_is_kind_of(arg1, rb_cArray) ||
- !rb_obj_is_kind_of(arg2, rb_cArray))
- ossl_raise(rb_eTypeError, "points must be array");
+ Check_Type(arg1, T_ARRAY);
+ Check_Type(arg2, T_ARRAY);
if (RARRAY_LEN(arg1) != RARRAY_LEN(arg2) + 1) /* arg2 must be 1 larger */
ossl_raise(rb_eArgError, "bns must be 1 longer than points; see the documentation");
@@ -1731,7 +1730,7 @@ static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
for (i = 0; i < num - 1; i++)
SafeRequire_EC_POINT(RARRAY_AREF(arg2, i), points[i + 1]);
- if (argc >= 3)
+ if (!NIL_P(arg3))
bn_g = GetBNPtr(arg3);
if (EC_POINTs_mul(group, point_result, bn_g, num, points, bignums, ossl_bn_ctx) != 1) {
diff --git a/test/test_pkey_ec.rb b/test/test_pkey_ec.rb
index b89fa38..53aa5a1 100644
--- a/test/test_pkey_ec.rb
+++ b/test/test_pkey_ec.rb
@@ -262,24 +262,25 @@ class OpenSSL::TestEC < OpenSSL::PKeyTestCase
# y^2 = x^3 + 2x + 2 over F_17
# generator is (5, 1)
group = OpenSSL::PKey::EC::Group.new(:GFp, 17, 2, 2)
+ group.point_conversion_form = :uncompressed
gen = OpenSSL::PKey::EC::Point.new(group, OpenSSL::BN.new("040501", 16))
group.set_generator(gen, 0, 0)
# 3 * (6, 3) = (16, 13)
point_a = OpenSSL::PKey::EC::Point.new(group, OpenSSL::BN.new("040603", 16))
- result_a1 = point_a.mul(3.to_bn)
+ result_a1 = point_a.mul(3)
assert_equal("04100D", result_a1.to_bn.to_s(16))
# 3 * (6, 3) + 3 * (5, 1) = (7, 6)
- result_a2 = point_a.mul(3.to_bn, 3.to_bn)
+ result_a2 = point_a.mul(3, 3)
assert_equal("040706", result_a2.to_bn.to_s(16))
# 3 * point_a = 3 * (6, 3) = (16, 13)
- result_b1 = point_a.mul([3.to_bn], [])
+ result_b1 = point_a.mul([3], [])
assert_equal("04100D", result_b1.to_bn.to_s(16))
# 3 * point_a + 2 * point_a = 3 * (6, 3) + 2 * (6, 3) = (7, 11)
- result_b1 = point_a.mul([3.to_bn, 2.to_bn], [point_a])
+ result_b1 = point_a.mul([3, 2], [point_a])
assert_equal("04070B", result_b1.to_bn.to_s(16))
# 3 * point_a + 5 * point_a.group.generator = 3 * (6, 3) + 5 * (5, 1) = (13, 10)
- result_b1 = point_a.mul([3.to_bn], [], 5)
+ result_b1 = point_a.mul([3], [], 5)
assert_equal("040D0A", result_b1.to_bn.to_s(16))
rescue OpenSSL::PKey::EC::Group::Error
# CentOS patches OpenSSL to reject curves defined over Fp where p < 256 bits
@@ -293,8 +294,8 @@ class OpenSSL::TestEC < OpenSSL::PKeyTestCase
# invalid argument
point = p256_key.public_key
assert_raise(TypeError) { point.mul(nil) }
- assert_raise(ArgumentError) { point.mul([1.to_bn], [point]) }
- assert_raise(TypeError) { point.mul([1.to_bn], nil) }
+ assert_raise(ArgumentError) { point.mul([1], [point]) }
+ assert_raise(TypeError) { point.mul([1], nil) }
assert_raise(TypeError) { point.mul([nil], []) }
end