Skip to content

Commit

Permalink
ssl: support IO-like object as the underlying transport
Browse files Browse the repository at this point in the history
OpenSSL::SSL::SSLSocket currently requires a real IO (socket) object
because it passes the file descriptor to OpenSSL.

OpenSSL internally uses an I/O abstraction layer called BIO to interact
with the underlying socket. BIO is pluggable; the implementation can be
supplied by a user application as long as it implements the necessary
BIO functions. We can make our own BIO implementation ("BIO method")
that wraps any Ruby IO-like object using normal Ruby method calls.

Support for such an IO-like object is useful for establishing TLS
connections on top of non-OS sockets, such as another TLS connection or
an HTTP/2 tunnel.

For performance reason, this patch continues to use the original socket
BIO if the user passes a real IO object.
  • Loading branch information
rhenium committed Sep 5, 2024
1 parent 0ff6d21 commit 84ead32
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 23 deletions.
115 changes: 103 additions & 12 deletions ext/openssl/ossl_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,11 @@ static void
ossl_ssl_mark(void *ptr)
{
SSL *ssl = ptr;
rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx));
VALUE obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx);

// Ensure GC compaction won't move objects referenced by OpenSSL objects
rb_gc_mark(obj);
rb_gc_mark(rb_attr_get(obj, id_i_io));
}

static void
Expand Down Expand Up @@ -1601,13 +1605,29 @@ peeraddr_ip_str(VALUE self)
return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL);
}

static int
is_real_socket(VALUE io)
{
// FIXME: DO NOT MERGE
return 0;
return RB_TYPE_P(io, T_FILE);
}

/*
* call-seq:
* SSLSocket.new(io) => aSSLSocket
* SSLSocket.new(io, ctx) => aSSLSocket
*
* Creates a new SSL socket from _io_ which must be a real IO object (not an
* IO-like object that responds to read/write).
* Creates a new SSL socket from _io_ which must be an IO object
* or an IO-like object that at least implements the following methods:
*
* - <tt>write_nonblock</tt> with <tt>exception: false</tt>
* - <tt>read_nonblock</tt> with <tt>exception: false</tt>
* - <tt>wait_readable</tt>
* - <tt>wait_writable</tt>
* - <tt>flush</tt>
* - <tt>close</tt>
* - <tt>closed?</tt>
*
* If _ctx_ is provided the SSL Sockets initial params will be taken from
* the context.
Expand Down Expand Up @@ -1635,9 +1655,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
rb_ivar_set(self, id_i_context, v_ctx);
ossl_sslctx_setup(v_ctx);

if (rb_respond_to(io, rb_intern("nonblock=")))
rb_funcall(io, rb_intern("nonblock="), 1, Qtrue);
Check_Type(io, T_FILE);
if (is_real_socket(io)) {
rb_io_t *fptr;
GetOpenFile(io, fptr);
rb_io_set_nonblock(fptr);
}
else {
// Not meant to be a comprehensive check
if (!rb_respond_to(io, rb_intern("read_nonblock")) ||
!rb_respond_to(io, rb_intern("write_nonblock")))
rb_raise(rb_eTypeError, "io must be a real IO object or an IO-like "
"object that responds to read_nonblock and write_nonblock");
}
rb_ivar_set(self, id_i_io, io);

ssl = SSL_new(ctx);
Expand Down Expand Up @@ -1669,18 +1698,24 @@ ossl_ssl_setup(VALUE self)
{
VALUE io;
SSL *ssl;
rb_io_t *fptr;

GetSSL(self, ssl);
if (ssl_started(ssl))
return Qtrue;

io = rb_attr_get(self, id_i_io);
GetOpenFile(io, fptr);
rb_io_check_readable(fptr);
rb_io_check_writable(fptr);
if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
ossl_raise(eSSLError, "SSL_set_fd");
if (is_real_socket(io)) {
rb_io_t *fptr;
GetOpenFile(io, fptr);
rb_io_check_readable(fptr);
rb_io_check_writable(fptr);
if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
ossl_raise(eSSLError, "SSL_set_fd");
}
else {
BIO *bio = ossl_bio_new(io);
SSL_set_bio(ssl, bio, bio);
}

return Qtrue;
}
Expand All @@ -1691,6 +1726,38 @@ ossl_ssl_setup(VALUE self)
#define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret))
#endif

static void
check_bio_error(SSL *ssl, VALUE io, int ret)
{
if (is_real_socket(io))
return;

BIO *bio = SSL_get_rbio(ssl);
int state = ossl_bio_state(bio);
if (!state)
return;

/*
* Operation may succeed while the underlying socket reports an error in
* some cases. For example, when TLS 1.3 server tries to send a
* NewSessionTicket on a closed socket (IOW, when the client disconnects
* right after finishing a handshake).
*
* According to ssl/statem/statem_srvr.c conn_is_closed(), EPIPE and
* ECONNRESET may be ignored.
*
* FIXME BEFORE MERGE: Currently ignoring all SystemCallError.
*/
int error_code = SSL_get_error(ssl, ret);
if ((ret > 0 || error_code == SSL_ERROR_ZERO_RETURN || error_code == SSL_ERROR_SSL) &&
rb_obj_is_kind_of(rb_errinfo(), rb_eSystemCallError)) {
rb_set_errinfo(Qnil);
return;
}
ossl_clear_error();
rb_jump_tag(state);
}

static void
write_would_block(int nonblock)
{
Expand Down Expand Up @@ -1729,6 +1796,11 @@ no_exception_p(VALUE opts)
static void
io_wait_writable(VALUE io)
{
if (!is_real_socket(io)) {
if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL)))
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
return;
}
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
Expand All @@ -1743,6 +1815,11 @@ io_wait_writable(VALUE io)
static void
io_wait_readable(VALUE io)
{
if (!is_real_socket(io)) {
if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL)))
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
return;
}
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
Expand All @@ -1767,8 +1844,10 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
GetSSL(self, ssl);

VALUE io = rb_attr_get(self, id_i_io);

for (;;) {
ret = func(ssl);
check_bio_error(ssl, io, ret);

cb_state = rb_attr_get(self, ID_callback_state);
if (!NIL_P(cb_state)) {
Expand Down Expand Up @@ -1963,6 +2042,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
rb_str_locktmp(str);
for (;;) {
int nread = SSL_read(ssl, RSTRING_PTR(str), ilen);
check_bio_error(ssl, io, nread);

switch (ssl_get_error(ssl, nread)) {
case SSL_ERROR_NONE:
rb_str_unlocktmp(str);
Expand Down Expand Up @@ -2067,6 +2148,8 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts)

for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(tmp), num);
check_bio_error(ssl, io, nwritten);

switch (ssl_get_error(ssl, nwritten)) {
case SSL_ERROR_NONE:
return INT2NUM(nwritten);
Expand Down Expand Up @@ -2144,7 +2227,15 @@ ossl_ssl_stop(VALUE self)
GetSSL(self, ssl);
if (!ssl_started(ssl))
return Qnil;

ret = SSL_shutdown(ssl);

/* XXX: Suppressing errors from the underlying socket */
VALUE io = rb_attr_get(self, id_i_io);
BIO *bio = SSL_get_rbio(ssl);
if (!is_real_socket(io) && ossl_bio_state(bio))
rb_set_errinfo(Qnil);

if (ret == 1) /* Have already received close_notify */
return Qnil;
if (ret == 0) /* Sent close_notify, but we don't wait for reply */
Expand Down
43 changes: 43 additions & 0 deletions test/openssl/test_pair.rb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ def create_tcp_client(host, port)
end
end

module OpenSSL::SSLPairIOish
include OpenSSL::SSLPairM

def create_tcp_server(host, port)
Addrinfo.tcp(host, port).listen
end

class TCPSocketWrapper
def initialize(io) @io = io end
def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end
def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end
def wait_readable() @io.wait_readable end
def wait_writable() @io.wait_writable end
def flush() @io.flush end
def close() @io.close end
def closed?() @io.closed? end

# Only used within test_pair.rb
def write(*args) @io.write(*args) end
end

def create_tcp_client(host, port)
TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect)
end
end

module OpenSSL::TestEOF1M
def open_file(content)
ssl_pair { |s1, s2|
Expand Down Expand Up @@ -518,6 +544,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF1M
end

class OpenSSL::TestEOF1IOish < OpenSSL::TestCase
include OpenSSL::TestEOF
include OpenSSL::SSLPairIOish
include OpenSSL::TestEOF1M
end

class OpenSSL::TestEOF2 < OpenSSL::TestCase
include OpenSSL::TestEOF
include OpenSSL::SSLPair
Expand All @@ -530,6 +562,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF2M
end

class OpenSSL::TestEOF2IOish < OpenSSL::TestCase
include OpenSSL::TestEOF
include OpenSSL::SSLPairIOish
include OpenSSL::TestEOF2M
end

class OpenSSL::TestPair < OpenSSL::TestCase
include OpenSSL::SSLPair
include OpenSSL::TestPairM
Expand All @@ -540,4 +578,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestPairM
end

class OpenSSL::TestPairIOish < OpenSSL::TestCase
include OpenSSL::SSLPairIOish
include OpenSSL::TestPairM
end

end
70 changes: 59 additions & 11 deletions test/openssl/test_ssl.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,6 @@
if defined?(OpenSSL::SSL)

class OpenSSL::TestSSL < OpenSSL::SSLTestCase
def test_bad_socket
bad_socket = Struct.new(:sync).new
assert_raise TypeError do
socket = OpenSSL::SSL::SSLSocket.new bad_socket
# if the socket is not a T_FILE, `connect` will segv because it tries
# to get the underlying file descriptor but the API it calls assumes
# the object type is T_FILE
socket.connect
end
end

def test_ctx_options
ctx = OpenSSL::SSL::SSLContext.new

Expand Down Expand Up @@ -141,6 +130,65 @@ def test_socket_close_write
end
end

def test_synthetic_io_sanity_check
obj = Object.new
assert_raise_with_message(TypeError, /read_nonblock/) { OpenSSL::SSL::SSLSocket.new(obj) }

obj = Object.new
obj.define_singleton_method(:read_nonblock) { |*args, **kwargs| }
obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| }
assert_nothing_raised { OpenSSL::SSL::SSLSocket.new(obj) }
end

def test_synthetic_io
start_server do |port|
tcp = TCPSocket.new("127.0.0.1", port)
obj = Object.new
obj.define_singleton_method(:read_nonblock) { |maxlen, exception:|
tcp.read_nonblock(maxlen, exception: exception) }
obj.define_singleton_method(:write_nonblock) { |str, exception:|
tcp.write_nonblock(str, exception: exception) }
obj.define_singleton_method(:wait_readable) { tcp.wait_readable }
obj.define_singleton_method(:wait_writable) { tcp.wait_writable }
obj.define_singleton_method(:flush) { tcp.flush }
obj.define_singleton_method(:closed?) { tcp.closed? }

ssl = OpenSSL::SSL::SSLSocket.new(obj)
assert_same obj, ssl.to_io

ssl.connect
ssl.puts "abc"; assert_equal "abc\n", ssl.gets
ensure
ssl&.close
tcp&.close
end
end

def test_synthetic_io_write_nonblock_exception
start_server(ignore_listener_error: true) do |port|
tcp = TCPSocket.new("127.0.0.1", port)
obj = Object.new
[:read_nonblock, :wait_readable, :wait_writable, :flush, :closed?].each do |name|
obj.define_singleton_method(name) { |*args, **kwargs|
tcp.__send__(name, *args, **kwargs) }
end

# SSLSocket#connect calls write_nonblock at least twice: ClientHello and Finished
# Let's break the second call
called = 0
obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
raise "foo" if (called += 1) == 2
tcp.write_nonblock(*args, **kwargs)
}

ssl = OpenSSL::SSL::SSLSocket.new(obj)
assert_raise_with_message(RuntimeError, "foo") { ssl.connect }
ensure
ssl&.close
tcp&.close
end
end

def test_add_certificate
ctx_proc = -> ctx {
# Unset values set by start_server
Expand Down

0 comments on commit 84ead32

Please sign in to comment.