aboutsummaryrefslogtreecommitdiffstats
path: root/test/openssl/test_pair.rb
diff options
context:
space:
mode:
Diffstat (limited to 'test/openssl/test_pair.rb')
-rw-r--r--test/openssl/test_pair.rb523
1 files changed, 523 insertions, 0 deletions
diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb
new file mode 100644
index 00000000..e9cf98df
--- /dev/null
+++ b/test/openssl/test_pair.rb
@@ -0,0 +1,523 @@
+# frozen_string_literal: true
+require_relative 'utils'
+require_relative 'ut_eof'
+
+if defined?(OpenSSL)
+
+module OpenSSL::SSLPairM
+ def setup
+ svr_dn = OpenSSL::X509::Name.parse("/DC=org/DC=ruby-lang/CN=localhost")
+ ee_exts = [
+ ["keyUsage", "keyEncipherment,digitalSignature", true],
+ ]
+ @svr_key = OpenSSL::TestUtils::Fixtures.pkey("rsa-1")
+ @svr_cert = issue_cert(svr_dn, @svr_key, 1, ee_exts, nil, nil)
+ end
+
+ def ssl_pair
+ host = "127.0.0.1"
+ tcps = create_tcp_server(host, 0)
+ port = tcps.connect_address.ip_port
+
+ th = Thread.new {
+ sctx = OpenSSL::SSL::SSLContext.new
+ sctx.cert = @svr_cert
+ sctx.key = @svr_key
+ sctx.tmp_dh_callback = proc { OpenSSL::TestUtils::Fixtures.pkey("dh-1") }
+ sctx.options |= OpenSSL::SSL::OP_NO_COMPRESSION
+ ssls = OpenSSL::SSL::SSLServer.new(tcps, sctx)
+ ns = ssls.accept
+ ssls.close
+ ns
+ }
+
+ tcpc = create_tcp_client(host, port)
+ c = OpenSSL::SSL::SSLSocket.new(tcpc)
+ c.connect
+ s = th.value
+
+ yield c, s
+ ensure
+ tcpc&.close
+ tcps&.close
+ s&.close
+ end
+end
+
+module OpenSSL::SSLPair
+ include OpenSSL::SSLPairM
+
+ def create_tcp_server(host, port)
+ TCPServer.new(host, port)
+ end
+
+ def create_tcp_client(host, port)
+ TCPSocket.new(host, port)
+ end
+end
+
+module OpenSSL::SSLPairLowlevelSocket
+ include OpenSSL::SSLPairM
+
+ def create_tcp_server(host, port)
+ Addrinfo.tcp(host, port).listen
+ end
+
+ def create_tcp_client(host, port)
+ Addrinfo.tcp(host, port).connect
+ end
+end
+
+module OpenSSL::TestEOF1M
+ def open_file(content)
+ ssl_pair { |s1, s2|
+ begin
+ th = Thread.new { s2 << content; s2.close }
+ yield s1
+ ensure
+ th&.join
+ end
+ }
+ end
+end
+
+module OpenSSL::TestEOF2M
+ def open_file(content)
+ ssl_pair { |s1, s2|
+ begin
+ th = Thread.new { s1 << content; s1.close }
+ yield s2
+ ensure
+ th&.join
+ end
+ }
+ end
+end
+
+module OpenSSL::TestPairM
+ def test_getc
+ ssl_pair {|s1, s2|
+ s1 << "a"
+ assert_equal(?a, s2.getc)
+ }
+ end
+
+ def test_gets
+ ssl_pair {|s1, s2|
+ s1 << "abc\n\n$def123ghi"
+ s1.close
+ ret = s2.gets
+ assert_equal Encoding::BINARY, ret.encoding
+ assert_equal "abc\n", ret
+ assert_equal "\n$", s2.gets("$")
+ assert_equal "def123", s2.gets(/\d+/)
+ assert_equal "ghi", s2.gets
+ assert_equal nil, s2.gets
+ }
+ end
+
+ def test_gets_eof_limit
+ ssl_pair {|s1, s2|
+ s1.write("hello")
+ s1.close # trigger EOF
+ assert_match "hello", s2.gets("\n", 6), "[ruby-core:70149] [Bug #11400]"
+ }
+ end
+
+ def test_readpartial
+ ssl_pair {|s1, s2|
+ s2.write "a\nbcd"
+ assert_equal("a\n", s1.gets)
+ result = String.new
+ result << s1.readpartial(10) until result.length == 3
+ assert_equal("bcd", result)
+ s2.write "efg"
+ result = String.new
+ result << s1.readpartial(10) until result.length == 3
+ assert_equal("efg", result)
+ s2.close
+ assert_raise(EOFError) { s1.readpartial(10) }
+ assert_raise(EOFError) { s1.readpartial(10) }
+ assert_equal("", s1.readpartial(0))
+ }
+ end
+
+ def test_readall
+ ssl_pair {|s1, s2|
+ s2.close
+ assert_equal("", s1.read)
+ }
+ end
+
+ def test_readline
+ ssl_pair {|s1, s2|
+ s2.close
+ assert_raise(EOFError) { s1.readline }
+ }
+ end
+
+ def test_puts_meta
+ ssl_pair {|s1, s2|
+ begin
+ old = $/
+ $/ = '*'
+ s1.puts 'a'
+ ensure
+ $/ = old
+ end
+ s1.close
+ assert_equal("a\n", s2.read)
+ }
+ end
+
+ def test_puts_empty
+ ssl_pair {|s1, s2|
+ s1.puts
+ s1.close
+ assert_equal("\n", s2.read)
+ }
+ end
+
+ def test_multibyte_read_write
+ # German a umlaut
+ auml = [%w{ C3 A4 }.join('')].pack('H*')
+ auml.force_encoding(Encoding::UTF_8)
+ bsize = auml.bytesize
+
+ ssl_pair { |s1, s2|
+ assert_equal bsize, s1.write(auml)
+ read = s2.read(bsize)
+ assert_equal Encoding::ASCII_8BIT, read.encoding
+ assert_equal bsize, read.bytesize
+ assert_equal auml, read.force_encoding(Encoding::UTF_8)
+
+ s1.puts(auml)
+ read = s2.gets
+ assert_equal Encoding::ASCII_8BIT, read.encoding
+ assert_equal bsize + 1, read.bytesize
+ assert_equal auml + "\n", read.force_encoding(Encoding::UTF_8)
+ }
+ end
+
+ def test_read_nonblock
+ ssl_pair {|s1, s2|
+ err = nil
+ assert_raise(OpenSSL::SSL::SSLErrorWaitReadable) {
+ begin
+ s2.read_nonblock(10)
+ ensure
+ err = $!
+ end
+ }
+ assert_kind_of(IO::WaitReadable, err)
+ s1.write "abc\ndef\n"
+ IO.select([s2])
+ assert_equal("ab", s2.read_nonblock(2))
+ assert_equal("c\n", s2.gets)
+ ret = nil
+ assert_nothing_raised("[ruby-core:20298]") { ret = s2.read_nonblock(10) }
+ assert_equal("def\n", ret)
+ s1.close
+ IO.select([s2])
+ assert_raise(EOFError) { s2.read_nonblock(10) }
+ }
+ end
+
+ def test_read_nonblock_no_exception
+ ssl_pair {|s1, s2|
+ assert_equal :wait_readable, s2.read_nonblock(10, exception: false)
+ s1.write "abc\ndef\n"
+ IO.select([s2])
+ assert_equal("ab", s2.read_nonblock(2, exception: false))
+ assert_equal("c\n", s2.gets)
+ ret = nil
+ assert_nothing_raised("[ruby-core:20298]") { ret = s2.read_nonblock(10, exception: false) }
+ assert_equal("def\n", ret)
+ s1.close
+ IO.select([s2])
+ assert_equal(nil, s2.read_nonblock(10, exception: false))
+ }
+ end
+
+ def test_read_with_outbuf
+ ssl_pair { |s1, s2|
+ s1.write("abc\n")
+ buf = String.new
+ ret = s2.read(2, buf)
+ assert_same ret, buf
+ assert_equal "ab", ret
+
+ buf = +"garbage"
+ ret = s2.read(2, buf)
+ assert_same ret, buf
+ assert_equal "c\n", ret
+
+ buf = +"garbage"
+ assert_equal :wait_readable, s2.read_nonblock(100, buf, exception: false)
+ assert_equal "", buf
+
+ s1.close
+ buf = +"garbage"
+ assert_equal nil, s2.read(100, buf)
+ assert_equal "", buf
+ }
+ end
+
+ def test_write_nonblock
+ ssl_pair {|s1, s2|
+ assert_equal 3, s1.write_nonblock("foo")
+ assert_equal "foo", s2.read(3)
+
+ data = "x" * 16384
+ written = 0
+ while true
+ begin
+ written += s1.write_nonblock(data)
+ rescue IO::WaitWritable, IO::WaitReadable
+ break
+ end
+ end
+ assert written > 0
+ assert_equal written, s2.read(written).bytesize
+ }
+ end
+
+ def test_write_nonblock_no_exceptions
+ ssl_pair {|s1, s2|
+ assert_equal 3, s1.write_nonblock("foo", exception: false)
+ assert_equal "foo", s2.read(3)
+
+ data = "x" * 16384
+ written = 0
+ while true
+ case ret = s1.write_nonblock(data, exception: false)
+ when :wait_readable, :wait_writable
+ break
+ else
+ written += ret
+ end
+ end
+ assert written > 0
+ assert_equal written, s2.read(written).bytesize
+ }
+ end
+
+ def test_write_nonblock_with_buffered_data
+ ssl_pair {|s1, s2|
+ s1.write "foo"
+ s1.write_nonblock("bar")
+ s1.write "baz"
+ s1.close
+ assert_equal("foobarbaz", s2.read)
+ }
+ end
+
+ def test_write_nonblock_with_buffered_data_no_exceptions
+ ssl_pair {|s1, s2|
+ s1.write "foo"
+ s1.write_nonblock("bar", exception: false)
+ s1.write "baz"
+ s1.close
+ assert_equal("foobarbaz", s2.read)
+ }
+ end
+
+ def test_write_nonblock_retry
+ ssl_pair {|s1, s2|
+ # fill up a socket so we hit EAGAIN
+ written = String.new
+ n = 0
+ buf = 'a' * 4099
+ case ret = s1.write_nonblock(buf, exception: false)
+ when :wait_readable then break
+ when :wait_writable then break
+ when Integer
+ written << buf
+ n += ret
+ exp = buf.bytesize
+ if ret != exp
+ buf = buf.byteslice(ret, exp - ret)
+ end
+ end while true
+ assert_kind_of Symbol, ret
+
+ # make more space for subsequent write:
+ readed = s2.read(n)
+ assert_equal written, readed
+
+ # this fails if SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER is missing:
+ buf2 = Marshal.load(Marshal.dump(buf))
+ assert_kind_of Integer, s1.write_nonblock(buf2, exception: false)
+ }
+ end
+
+ def test_write_zero
+ ssl_pair {|s1, s2|
+ assert_equal 0, s2.write_nonblock('', exception: false)
+ assert_kind_of Symbol, s1.read_nonblock(1, exception: false)
+ assert_equal 0, s2.syswrite('')
+ assert_kind_of Symbol, s1.read_nonblock(1, exception: false)
+ assert_equal 0, s2.write('')
+ assert_kind_of Symbol, s1.read_nonblock(1, exception: false)
+ }
+ end
+
+ def test_write_multiple_arguments
+ ssl_pair {|s1, s2|
+ str1 = "foo"; str2 = "bar"
+ assert_equal 6, s1.write(str1, str2)
+ s1.close
+ assert_equal "foobar", s2.read
+ }
+ end
+
+ def test_partial_tls_record_read_nonblock
+ ssl_pair { |s1, s2|
+ # the beginning of a TLS record
+ s1.io.write("\x17")
+ # should raise a IO::WaitReadable since a full TLS record is not available
+ # for reading
+ assert_raise(IO::WaitReadable) { s2.read_nonblock(1) }
+ }
+ end
+
+ def tcp_pair
+ host = "127.0.0.1"
+ serv = TCPServer.new(host, 0)
+ port = serv.connect_address.ip_port
+ sock1 = TCPSocket.new(host, port)
+ sock2 = serv.accept
+ serv.close
+ [sock1, sock2]
+ ensure
+ serv.close if serv && !serv.closed?
+ end
+
+ def test_connect_accept_nonblock_no_exception
+ ctx2 = OpenSSL::SSL::SSLContext.new
+ ctx2.cert = @svr_cert
+ ctx2.key = @svr_key
+ ctx2.tmp_dh_callback = proc { OpenSSL::TestUtils::Fixtures.pkey("dh-1") }
+
+ sock1, sock2 = tcp_pair
+
+ s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2)
+ accepted = s2.accept_nonblock(exception: false)
+ assert_equal :wait_readable, accepted
+
+ ctx1 = OpenSSL::SSL::SSLContext.new
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1)
+ th = Thread.new do
+ rets = []
+ begin
+ rv = s1.connect_nonblock(exception: false)
+ rets << rv
+ case rv
+ when :wait_writable
+ IO.select(nil, [s1], nil, 5)
+ when :wait_readable
+ IO.select([s1], nil, nil, 5)
+ end
+ end until rv == s1
+ rets
+ end
+
+ until th.join(0.01)
+ accepted = s2.accept_nonblock(exception: false)
+ assert_include([s2, :wait_readable, :wait_writable ], accepted)
+ end
+
+ rets = th.value
+ assert_instance_of Array, rets
+ rets.each do |rv|
+ assert_include([s1, :wait_readable, :wait_writable ], rv)
+ end
+ ensure
+ th.join if th
+ s1.close if s1
+ s2.close if s2
+ sock1.close if sock1
+ sock2.close if sock2
+ accepted.close if accepted.respond_to?(:close)
+ end
+
+ def test_connect_accept_nonblock
+ ctx = OpenSSL::SSL::SSLContext.new
+ ctx.cert = @svr_cert
+ ctx.key = @svr_key
+ ctx.tmp_dh_callback = proc { OpenSSL::TestUtils::Fixtures.pkey("dh-1") }
+
+ sock1, sock2 = tcp_pair
+
+ th = Thread.new {
+ s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx)
+ 5.times {
+ begin
+ break s2.accept_nonblock
+ rescue IO::WaitReadable
+ IO.select([s2], nil, nil, 1)
+ rescue IO::WaitWritable
+ IO.select(nil, [s2], nil, 1)
+ end
+ sleep 0.2
+ }
+ }
+
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1)
+ 5.times {
+ begin
+ break s1.connect_nonblock
+ rescue IO::WaitReadable
+ IO.select([s1], nil, nil, 1)
+ rescue IO::WaitWritable
+ IO.select(nil, [s1], nil, 1)
+ end
+ sleep 0.2
+ }
+
+ s2 = th.value
+
+ s1.print "a\ndef"
+ assert_equal("a\n", s2.gets)
+ ensure
+ sock1&.close
+ sock2&.close
+ th&.join
+ end
+end
+
+class OpenSSL::TestEOF1 < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPair
+ include OpenSSL::TestEOF1M
+end
+
+class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairLowlevelSocket
+ include OpenSSL::TestEOF1M
+end
+
+class OpenSSL::TestEOF2 < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPair
+ include OpenSSL::TestEOF2M
+end
+
+class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairLowlevelSocket
+ include OpenSSL::TestEOF2M
+end
+
+class OpenSSL::TestPair < OpenSSL::TestCase
+ include OpenSSL::SSLPair
+ include OpenSSL::TestPairM
+end
+
+class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase
+ include OpenSSL::SSLPairLowlevelSocket
+ include OpenSSL::TestPairM
+end
+
+end