diff --git a/lib/openssl/ssl.rb b/lib/openssl/ssl.rb index 9f3afc2f6..6380b965a 100644 --- a/lib/openssl/ssl.rb +++ b/lib/openssl/ssl.rb @@ -448,18 +448,28 @@ class << self # # Creates a new instance of SSLSocket. # _remote\_host_ and _remote_port_ are used to open TCPSocket. - # If _context_ is provided, - # the SSL Sockets initial params will be taken from the context. # If _local\_host_ and _local\_port_ are specified, # then those parameters are used on the local end to establish the connection. + # If _context_ is provided, + # the SSL Sockets initial params will be taken from the context. # - # === Example - # ctx = OpenSSL::SSL::SSLContext.new - # sock = OpenSSL::SSL::SSLSocket.open('localhost', 443, ctx) + # === Examples + # + # sock = OpenSSL::SSL::SSLSocket.open('localhost', 443) # sock.connect # Initiates a connection to localhost:443 - def open(remote_host, remote_port, context=nil, local_host=nil, local_port=nil) + # + # with SSLContext: + # + # ctx = OpenSSL::SSL::SSLContext.new + # sock = OpenSSL::SSL::SSLSocket.open('localhost', 443, context: ctx) + # sock.connect # Initiates a connection to localhost:443 with SSLContext + def open(remote_host, remote_port, local_host=nil, local_port=nil, context: nil) sock = ::TCPSocket.open(remote_host, remote_port, local_host, local_port) - OpenSSL::SSL::SSLSocket.new(sock, context) + if context.nil? + return OpenSSL::SSL::SSLSocket.new(sock) + else + return OpenSSL::SSL::SSLSocket.new(sock, context) + end end end end diff --git a/test/test_ssl.rb b/test/test_ssl.rb index cf1cb7e90..96844ee78 100644 --- a/test/test_ssl.rb +++ b/test/test_ssl.rb @@ -56,14 +56,45 @@ def test_ssl_with_server_cert } end - def test_ssl_socket_open + def test_socket_open + start_server { |port| + begin + ssl = OpenSSL::SSL::SSLSocket.open("127.0.0.1", port) + ssl.sync_close = true + ssl.connect + + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + ensure + ssl&.close + end + } + end + + def test_socket_open_with_context + start_server { |port| + begin + ctx = OpenSSL::SSL::SSLContext.new + ssl = OpenSSL::SSL::SSLSocket.open("127.0.0.1", port, context: ctx) + ssl.sync_close = true + ssl.connect + + assert_equal ssl.context, ctx + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + ensure + ssl&.close + end + } + end + + def test_socket_open_with_local_address_port_context start_server { |port| begin ctx = OpenSSL::SSL::SSLContext.new - ssl = OpenSSL::SSL::SSLSocket.open("127.0.0.1", port, ctx) + ssl = OpenSSL::SSL::SSLSocket.open("127.0.0.1", port, "127.0.0.1", 8000, context: ctx) ssl.sync_close = true ssl.connect + assert_equal ssl.context, ctx ssl.puts "abc"; assert_equal "abc\n", ssl.gets ensure ssl&.close