aboutsummaryrefslogtreecommitdiffstats
path: root/ext/openssl/ossl_pkey_ec.c
diff options
context:
space:
mode:
Diffstat (limited to 'ext/openssl/ossl_pkey_ec.c')
-rw-r--r--ext/openssl/ossl_pkey_ec.c118
1 files changed, 64 insertions, 54 deletions
diff --git a/ext/openssl/ossl_pkey_ec.c b/ext/openssl/ossl_pkey_ec.c
index c93e3cfb99..8f6edfab26 100644
--- a/ext/openssl/ossl_pkey_ec.c
+++ b/ext/openssl/ossl_pkey_ec.c
@@ -1500,74 +1500,84 @@ static VALUE ossl_ec_point_to_bn(VALUE self)
/*
* call-seq:
- * point.mul(bn) => point
- * point.mul(bn, bn) => point
- * point.mul([bn], [point]) => point
- * point.mul([bn], [point], bn) => point
+ * point.mul(bn1 [, bn2]) => point
+ * point.mul(bns, points [, bn2]) => point
+ *
+ * Performs elliptic curve point multiplication.
+ *
+ * The first form calculates <tt>bn1 * point + bn2 * G</tt>, where +G+ is the
+ * generator of the group of +point+. +bn2+ may be ommitted, and in that case,
+ * the result is just <tt>bn1 * point</tt>.
+ *
+ * The second form calculates <tt>bns[0] * point + bns[1] * points[0] + ...
+ * + bns[-1] * points[-1] + bn2 * G</tt>. +bn2+ may be ommitted. +bns+ must be
+ * an array of OpenSSL::BN. +points+ must be an array of
+ * OpenSSL::PKey::EC::Point. Please note that <tt>points[0]</tt> is not
+ * multiplied by <tt>bns[0]</tt>, but <tt>bns[1]</tt>.
*/
static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
{
- EC_POINT *point1, *point2;
+ EC_POINT *point_self, *point_result;
const EC_GROUP *group;
VALUE group_v = rb_iv_get(self, "@group");
- VALUE bn_v1, bn_v2, r, points_v;
- BIGNUM *bn1 = NULL, *bn2 = NULL;
+ VALUE arg1, arg2, arg3, result;
+ const BIGNUM *bn_g = NULL;
- Require_EC_POINT(self, point1);
+ Require_EC_POINT(self, point_self);
SafeRequire_EC_GROUP(group_v, group);
- r = rb_obj_alloc(cEC_POINT);
- ossl_ec_point_initialize(1, &group_v, r);
- Require_EC_POINT(r, point2);
+ result = rb_obj_alloc(cEC_POINT);
+ ossl_ec_point_initialize(1, &group_v, result);
+ Require_EC_POINT(result, point_result);
- argc = rb_scan_args(argc, argv, "12", &bn_v1, &points_v, &bn_v2);
+ rb_scan_args(argc, argv, "12", &arg1, &arg2, &arg3);
+ if (rb_obj_is_kind_of(arg1, cBN)) {
+ BIGNUM *bn = GetBNPtr(arg1);
+ if (argc >= 2)
+ bn_g = GetBNPtr(arg2);
- if (rb_obj_is_kind_of(bn_v1, cBN)) {
- bn1 = GetBNPtr(bn_v1);
- if (argc >= 2) {
- bn2 = GetBNPtr(points_v);
- }
- if (EC_POINT_mul(group, point2, bn2, point1, bn1, ossl_bn_ctx) != 1)
- ossl_raise(eEC_POINT, "Multiplication failed");
+ if (EC_POINT_mul(group, point_result, bn_g, point_self, bn, ossl_bn_ctx) != 1)
+ ossl_raise(eEC_POINT, NULL);
} else {
- size_t i, points_len, bignums_len;
- const EC_POINT **points;
- const BIGNUM **bignums;
-
- Check_Type(bn_v1, T_ARRAY);
- bignums_len = RARRAY_LEN(bn_v1);
- bignums = (const BIGNUM **)OPENSSL_malloc(bignums_len * (int)sizeof(BIGNUM *));
-
- for (i = 0; i < bignums_len; ++i) {
- bignums[i] = GetBNPtr(rb_ary_entry(bn_v1, i));
- }
-
- if (!rb_obj_is_kind_of(points_v, rb_cArray)) {
- OPENSSL_free((void *)bignums);
- rb_raise(rb_eTypeError, "Argument2 must be an array");
- }
-
- rb_ary_unshift(points_v, self);
- points_len = RARRAY_LEN(points_v);
- points = (const EC_POINT **)OPENSSL_malloc(points_len * (int)sizeof(EC_POINT *));
-
- for (i = 0; i < points_len; ++i) {
- Get_EC_POINT(rb_ary_entry(points_v, i), points[i]);
- }
+ /*
+ * bignums | arg1[0] | arg1[1] | arg1[2] | ...
+ * points | self | arg2[0] | arg2[1] | ...
+ */
+ int i, num;
+ VALUE tmp_p, tmp_b;
+ const EC_POINT **points;
+ const BIGNUM **bignums;
+
+ if (!rb_obj_is_kind_of(arg1, rb_cArray) ||
+ !rb_obj_is_kind_of(arg2, rb_cArray))
+ ossl_raise(rb_eTypeError, "points must be array");
+ if (RARRAY_LEN(arg1) != RARRAY_LEN(arg2) + 1) /* arg2 must be 1 larger */
+ ossl_raise(rb_eArgError, "bns must be 1 longer than points; see the documentation");
+
+ num = RARRAY_LEN(arg1);
+ bignums = ALLOCV_N(const BIGNUM *, tmp_b, num);
+ for (i = 0; i < num; i++)
+ bignums[i] = GetBNPtr(RARRAY_AREF(arg1, i));
+
+ points = ALLOCV_N(const EC_POINT *, tmp_p, num);
+ points[0] = point_self; /* self */
+ for (i = 0; i < num - 1; i++)
+ SafeRequire_EC_POINT(RARRAY_AREF(arg2, i), points[i + 1]);
+
+ if (argc >= 3)
+ bn_g = GetBNPtr(arg3);
+
+ if (EC_POINTs_mul(group, point_result, bn_g, num, points, bignums, ossl_bn_ctx) != 1) {
+ ALLOCV_END(tmp_b);
+ ALLOCV_END(tmp_p);
+ ossl_raise(eEC_POINT, NULL);
+ }
- if (argc >= 3) {
- bn2 = GetBNPtr(bn_v2);
- }
- if (EC_POINTs_mul(group, point2, bn2, points_len, points, bignums, ossl_bn_ctx) != 1) {
- OPENSSL_free((void *)bignums);
- OPENSSL_free((void *)points);
- ossl_raise(eEC_POINT, "Multiplication failed");
- }
- OPENSSL_free((void *)bignums);
- OPENSSL_free((void *)points);
+ ALLOCV_END(tmp_b);
+ ALLOCV_END(tmp_p);
}
- return r;
+ return result;
}
static void no_copy(VALUE klass)