aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKazuki Yamaguchi <k@rhe.jp>2016-12-08 15:23:39 +0900
committerKazuki Yamaguchi <k@rhe.jp>2016-12-23 13:14:09 +0900
commitfaaa3021385b5432dc960dfc8ca55d1b2fc89d3b (patch)
tree18695137e7351a5830b882f5626ac433a1b526d1
parent9b30cc419be9e1485ce535463a159c9ce21189d4 (diff)
downloadruby-openssl-faaa3021385b5432dc960dfc8ca55d1b2fc89d3b.tar.gz
bn: implement unary {plus,minus} operators for OpenSSL::BN
For consistency with Numeric. Not sure why they aren't currently; maybe they were simply forgotten.
-rw-r--r--ext/openssl/ossl_bn.c34
-rw-r--r--test/test_bn.rb7
2 files changed, 41 insertions, 0 deletions
diff --git a/ext/openssl/ossl_bn.c b/ext/openssl/ossl_bn.c
index 4e371cb2..1afebf44 100644
--- a/ext/openssl/ossl_bn.c
+++ b/ext/openssl/ossl_bn.c
@@ -856,6 +856,37 @@ ossl_bn_copy(VALUE self, VALUE other)
return self;
}
+/*
+ * call-seq:
+ * +bn -> aBN
+ */
+static VALUE
+ossl_bn_uplus(VALUE self)
+{
+ return self;
+}
+
+/*
+ * call-seq:
+ * -bn -> aBN
+ */
+static VALUE
+ossl_bn_uminus(VALUE self)
+{
+ VALUE obj;
+ BIGNUM *bn1, *bn2;
+
+ GetBN(self, bn1);
+ obj = NewBN(cBN);
+ bn2 = BN_dup(bn1);
+ if (!bn2)
+ ossl_raise(eBNError, "BN_dup");
+ SetBN(obj, bn2);
+ BN_set_negative(bn2, !BN_is_negative(bn2));
+
+ return obj;
+}
+
#define BIGNUM_CMP(func) \
static VALUE \
ossl_bn_##func(VALUE self, VALUE other) \
@@ -1068,6 +1099,9 @@ Init_ossl_bn(void)
rb_define_method(cBN, "num_bits", ossl_bn_num_bits, 0);
/* num_bits_word */
+ rb_define_method(cBN, "+@", ossl_bn_uplus, 0);
+ rb_define_method(cBN, "-@", ossl_bn_uminus, 0);
+
rb_define_method(cBN, "+", ossl_bn_add, 1);
rb_define_method(cBN, "-", ossl_bn_sub, 1);
rb_define_method(cBN, "*", ossl_bn_mul, 1);
diff --git a/test/test_bn.rb b/test/test_bn.rb
index 4ee7f4e0..ac4334fd 100644
--- a/test/test_bn.rb
+++ b/test/test_bn.rb
@@ -116,6 +116,13 @@ class OpenSSL::TestBN < OpenSSL::TestCase
assert_raise(OpenSSL::BNError) { 1.to_bn / 0 }
end
+ def test_unary_plus_minus
+ assert_equal(999, +@e1)
+ assert_equal(-999, +@e2)
+ assert_equal(-999, -@e1)
+ assert_equal(+999, -@e2)
+ end
+
def test_mod
assert_equal(1, 1.to_bn % 2)
assert_equal(0, 2.to_bn % 1)