aboutsummaryrefslogtreecommitdiffstats
path: root/test/utils.rb
blob: cf0e0e51a13bc7ddc5457863352322b6d12272a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
module Minitest::Assertions
  # Asserts that the block raises +klass+ and that the raised exception's
  # +http2_error_type+ equals +type+.
  #
  # Fails with "<Class> type: <type> expected but nothing was raised." when the
  # block completes normally. Exceptions that are not a +klass+ are not
  # rescued and propagate to the caller unchanged.
  #
  # @param klass [Class] exception class the block is expected to raise
  # @param type [Object] expected value of the exception's #http2_error_type
  def assert_http_error(klass, type, &blk)
    blk.call
  rescue klass => e
    assert_equal(type, e.http2_error_type)
  else
    flunk "#{klass.name} type: #{type} expected but nothing was raised."
  end

  # Asserts that the block raises a Plum::ConnectionError of the given type.
  #
  # @param type [Object] expected #http2_error_type of the raised error
  def assert_connection_error(type, &blk)
    assert_http_error(Plum::ConnectionError, type, &blk)
  end

  # Asserts that the block raises a Plum::StreamError of the given type.
  #
  # @param type [Object] expected #http2_error_type of the raised error
  def assert_stream_error(type, &blk)
    assert_http_error(Plum::StreamError, type, &blk)
  end

  # Asserts that the block does not raise a StandardError.
  #
  # On failure the message contains the exception's message and backtrace.
  # Exceptions outside the StandardError hierarchy are not rescued and
  # propagate unchanged.
  def refute_raises(&blk)
    error =
      begin
        blk.call
        nil
      rescue StandardError => e
        e
      end
    assert(!error, "No exceptions expected but raised: #{error}:\n#{error && error.backtrace.join("\n")}")
  end
end

module ServerTestUtils
  private

  # Creates a ServerConnection backed by an in-memory StringIO and feeds it the
  # client connection preface followed by a SETTINGS frame on stream 0.
  # The connection is remembered in @_con, which #sent_frames uses by default.
  #
  # @yieldparam con [ServerConnection] the new connection
  # @return the block's value when a block is given, otherwise the connection
  def open_server_connection
    io = StringIO.new
    @_con = ServerConnection.new(io)
    @_con << ServerConnection::CLIENT_CONNECTION_PREFACE
    @_con << Frame.new(type: :settings, stream_id: 0).assemble
    if block_given?
      yield @_con
    else
      @_con
    end
  end

  # Opens a new server connection (via #open_server_connection) and creates
  # stream 3 on it in the given state. The stream is remembered in @_stream.
  #
  # @param state [Symbol] initial stream state (default :idle)
  # @yieldparam stream the created stream
  # @return the block's value when a block is given, otherwise the stream
  def open_new_stream(state = :idle)
    open_server_connection do |con|
      # new_stream is reached through instance_eval, presumably because it is
      # not public on ServerConnection -- confirm against the library.
      @_stream = con.instance_eval { new_stream(3, state: state) }
      if block_given?
        yield @_stream
      else
        @_stream
      end
    end
  end

  # Decodes every frame written to a connection's socket so far.
  #
  # Works on a copy of the socket buffer, so the connection's output is left
  # untouched and repeated calls return the same frames. Assumes the socket is
  # a StringIO (as created by #open_server_connection) and that Frame.parse!
  # consumes bytes from its argument, returning a falsy value once no complete
  # frame remains -- TODO confirm.
  #
  # @param con [ServerConnection, nil] connection to inspect; defaults to @_con
  # @return [Array] frames in the order they were written
  def sent_frames(con = nil)
    buffer = (con || @_con).socket.string.dup
    frames = []
    while (frame = Frame.parse!(buffer))
      frames << frame
    end
    frames
  end

  # Runs a TLS server backed by a Plum::ServerConnection and a client callback
  # concurrently, each on its own thread.
  #
  # The server listens on 127.0.0.1:LISTEN_PORT using server.crt/server.key
  # from this file's directory, always answers ALPN with "h2", accepts one
  # connection, hands the accepted socket to the Plum connection and calls
  # #start on it. The block receives that Plum::ServerConnection. Each thread
  # is limited to 3 seconds and fails the test with "server timeout" /
  # "client timeout" when exceeded; the listening socket is closed once the
  # server thread ends.
  #
  # Failures raised on either thread (including flunk) resurface here through
  # Thread#join. If the client thread fails, its join raises first and the
  # server thread is not joined; it ends on its own timeout.
  #
  # @yieldparam plum [Plum::ServerConnection] the server-side connection
  def start_server(&blk)
    ctx = OpenSSL::SSL::SSLContext.new
    ctx.alpn_select_cb = ->(_protocols) { "h2" }
    ctx.cert = OpenSSL::X509::Certificate.new File.read(File.expand_path("../server.crt", __FILE__))
    ctx.key = OpenSSL::PKey::RSA.new File.read(File.expand_path("../server.key", __FILE__))
    tcp_server = TCPServer.new("127.0.0.1", LISTEN_PORT)
    ssl_server = OpenSSL::SSL::SSLServer.new(tcp_server, ctx)

    plum = Plum::ServerConnection.new(nil)

    # Kernel#timeout and the top-level TimeoutError constant have been
    # deprecated since Ruby 2.3 and were removed later; use the Timeout API.
    server_thread = Thread.new {
      begin
        Timeout.timeout(3) {
          sock = ssl_server.accept
          plum.instance_eval { @socket = sock }
          plum.start
        }
      rescue Timeout::Error
        flunk "server timeout"
      ensure
        tcp_server.close
      end
    }
    client_thread = Thread.new {
      begin
        Timeout.timeout(3) { blk.call(plum) }
      rescue Timeout::Error
        flunk "client timeout"
      end
    }
    client_thread.join
    server_thread.join
  end

  # Connects to 127.0.0.1:LISTEN_PORT over TLS, yields the connected
  # SSLSocket and returns the block's value. Both the SSL socket and the
  # underlying TCP socket are closed afterwards.
  #
  # @param ctx [OpenSSL::SSL::SSLContext, nil] client context; when nil, a
  #   fresh context advertising only "h2" via ALPN is used
  # @yieldparam ssl [OpenSSL::SSL::SSLSocket] the connected socket
  # @return the block's value
  def start_client(ctx = nil, &blk)
    unless ctx
      ctx = OpenSSL::SSL::SSLContext.new
      ctx.alpn_protocols = ["h2"]
    end

    sock = TCPSocket.new("127.0.0.1", LISTEN_PORT)
    ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
    ssl.connect
    blk.call(ssl)
  ensure
    # Either socket may be nil if connecting failed; guard so the original
    # error propagates instead of a NoMethodError on nil. SSLSocket#close
    # leaves the wrapped TCP socket open unless sync_close is set, so that
    # socket is closed explicitly too.
    begin
      ssl.close if ssl
    ensure
      sock.close if sock && !sock.closed?
    end
  end
end