diff --git a/src/inc/quic_datapath.h b/src/inc/quic_datapath.h index 78595b4adb..a5dd6a335c 100644 --- a/src/inc/quic_datapath.h +++ b/src/inc/quic_datapath.h @@ -282,11 +282,14 @@ typedef struct CXPLAT_QEO_CONNECTION { // // Function pointer type for datapath TCP accept callbacks. +// Any QUIC_FAILED status will reject the connection. +// Do not call CxPlatSocketDelete from this callback, it will +// crash. // typedef _IRQL_requires_max_(DISPATCH_LEVEL) _Function_class_(CXPLAT_DATAPATH_ACCEPT_CALLBACK) -void +QUIC_STATUS (CXPLAT_DATAPATH_ACCEPT_CALLBACK)( _In_ CXPLAT_SOCKET* ListenerSocket, _In_ void* ListenerContext, diff --git a/src/perf/lib/Tcp.cpp b/src/perf/lib/Tcp.cpp index 3aa4d6ad75..b71395c934 100644 --- a/src/perf/lib/Tcp.cpp +++ b/src/perf/lib/Tcp.cpp @@ -362,7 +362,7 @@ bool TcpServer::Start(const QUIC_ADDR* LocalAddress) _IRQL_requires_max_(DISPATCH_LEVEL) _Function_class_(CXPLAT_DATAPATH_ACCEPT_CALLBACK) -void +QUIC_STATUS TcpServer::AcceptCallback( _In_ CXPLAT_SOCKET* /* ListenerSocket */, _In_ void* ListenerContext, @@ -373,6 +373,7 @@ TcpServer::AcceptCallback( auto This = (TcpServer*)ListenerContext; auto Connection = new(std::nothrow) TcpConnection(This->Engine, This->SecConfig, AcceptSocket, This); *AcceptClientContext = Connection; + return QUIC_STATUS_SUCCESS; } // ############################ CONNECTION ############################ diff --git a/src/perf/lib/Tcp.h b/src/perf/lib/Tcp.h index 25a7ba25cf..5840cef8f0 100644 --- a/src/perf/lib/Tcp.h +++ b/src/perf/lib/Tcp.h @@ -155,7 +155,7 @@ class TcpServer { static _IRQL_requires_max_(DISPATCH_LEVEL) _Function_class_(CXPLAT_DATAPATH_ACCEPT_CALLBACK) - void + QUIC_STATUS AcceptCallback( _In_ CXPLAT_SOCKET* ListenerSocket, _In_ void* ListenerContext, diff --git a/src/platform/datapath_epoll.c b/src/platform/datapath_epoll.c index e416e05356..7007e41217 100644 --- a/src/platform/datapath_epoll.c +++ b/src/platform/datapath_epoll.c @@ -901,7 +901,7 @@ CxPlatSocketContextInitialize( // Only set SO_REUSEPORT on a server socket, otherwise the client could be // assigned a server port (unless it's forcing sharing). // - if ((Config->Flags & CXPLAT_SOCKET_FLAG_SHARE || Config->RemoteAddress == NULL) && + if ((Config->Flags & CXPLAT_SOCKET_FLAG_SHARE || Config->RemoteAddress == NULL) && SocketContext->Binding->Datapath->PartitionCount > 1) { // // The port is shared across processors. @@ -1552,11 +1552,14 @@ CxPlatSocketContextAcceptCompletion( CxPlatSocketContextSetEvents(&SocketContext->AcceptSocket->SocketContexts[0], EPOLL_CTL_ADD, EPOLLIN); SocketContext->AcceptSocket->SocketContexts[0].IoStarted = TRUE; - Datapath->TcpHandlers.Accept( + Status = Datapath->TcpHandlers.Accept( SocketContext->Binding, SocketContext->Binding->ClientContext, SocketContext->AcceptSocket, &SocketContext->AcceptSocket->ClientContext); + if (QUIC_FAILED(Status)) { + goto Error; + } SocketContext->AcceptSocket = NULL; diff --git a/src/platform/datapath_winuser.c b/src/platform/datapath_winuser.c index 2078533eed..6610cd8118 100644 --- a/src/platform/datapath_winuser.c +++ b/src/platform/datapath_winuser.c @@ -2667,11 +2667,15 @@ CxPlatDataPathSocketProcessAcceptCompletion( goto Error; } - Datapath->TcpHandlers.Accept( + QUIC_STATUS Status = Datapath->TcpHandlers.Accept( ListenerSocketProc->Parent, ListenerSocketProc->Parent->ClientContext, ListenerSocketProc->AcceptSocket, &ListenerSocketProc->AcceptSocket->ClientContext); + if (QUIC_FAILED(Status)) { + goto Error; + } + ListenerSocketProc->AcceptSocket = NULL; AcceptSocketProc->IoStarted = TRUE; diff --git a/src/platform/unittest/DataPathTest.cpp b/src/platform/unittest/DataPathTest.cpp index d060e810b0..6affaf85b7 100644 --- a/src/platform/unittest/DataPathTest.cpp +++ b/src/platform/unittest/DataPathTest.cpp @@ -117,8 +117,10 @@ struct TcpListenerContext { CXPLAT_SOCKET* Server; TcpClientContext ServerContext; bool Accepted : 1; + bool Reject : 1; + bool Rejected : 1; CXPLAT_EVENT AcceptEvent; - TcpListenerContext() : Server(nullptr), Accepted(false) { + TcpListenerContext() : Server(nullptr), Accepted(false), Reject{false}, Rejected{false} { CxPlatEventInitialize(&AcceptEvent, FALSE, FALSE); } ~TcpListenerContext() { @@ -317,7 +319,7 @@ struct DataPathTest : public ::testing::TestWithParam CxPlatRecvDataReturn(RecvDataChain); } - static void + static QUIC_STATUS EmptyAcceptCallback( _In_ CXPLAT_SOCKET* /* ListenerSocket */, _In_ void* /* ListenerContext */, @@ -325,6 +327,8 @@ struct DataPathTest : public ::testing::TestWithParam _Out_ void** /* ClientContext */ ) { + // If we somehow get a connection here, reject it + return QUIC_STATUS_CONNECTION_REFUSED; } static void @@ -336,7 +340,7 @@ struct DataPathTest : public ::testing::TestWithParam { } - static void + static QUIC_STATUS TcpAcceptCallback( _In_ CXPLAT_SOCKET* /* ListenerSocket */, _In_ void* Context, @@ -345,10 +349,16 @@ struct DataPathTest : public ::testing::TestWithParam ) { TcpListenerContext* ListenerContext = (TcpListenerContext*)Context; + if (ListenerContext->Reject) { + ListenerContext->Rejected = true; + CxPlatEventSet(ListenerContext->AcceptEvent); + return QUIC_STATUS_CONNECTION_REFUSED; + } ListenerContext->Server = ClientSocket; *ClientContext = &ListenerContext->ServerContext; ListenerContext->Accepted = true; CxPlatEventSet(ListenerContext->AcceptEvent); + return QUIC_STATUS_SUCCESS; } static void @@ -1077,6 +1087,43 @@ TEST_P(DataPathTest, TcpConnect) ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.DisconnectEvent, 500)); } +TEST_P(DataPathTest, TcpRejectConnect) +{ + CxPlatDataPath Datapath(nullptr, &TcpRecvCallbacks); + if (!Datapath.IsSupported(CXPLAT_DATAPATH_FEATURE_TCP)) { + GTEST_SKIP_("TCP is not supported"); + } + VERIFY_QUIC_SUCCESS(Datapath.GetInitStatus()); + ASSERT_NE(nullptr, Datapath.Datapath); + + TcpListenerContext ListenerContext; + auto serverAddress = GetNewLocalAddr(); + CxPlatSocket Listener; Listener.CreateTcpListener(Datapath, &serverAddress.SockAddr, &ListenerContext); + while (Listener.GetInitStatus() == QUIC_STATUS_ADDRESS_IN_USE) { + serverAddress.SockAddr.Ipv4.sin_port = GetNextPort(); + Listener.CreateTcpListener(Datapath, &serverAddress.SockAddr, &ListenerContext); + } + VERIFY_QUIC_SUCCESS(Listener.GetInitStatus()); + ASSERT_NE(nullptr, Listener.Socket); + serverAddress.SockAddr = Listener.GetLocalAddress(); + ASSERT_NE(serverAddress.SockAddr.Ipv4.sin_port, (uint16_t)0); + + ListenerContext.Reject = true; + + TcpClientContext ClientContext; + CxPlatSocket Client; Client.CreateTcp(Datapath, nullptr, &serverAddress.SockAddr, &ClientContext); + VERIFY_QUIC_SUCCESS(Client.GetInitStatus()); + ASSERT_NE(nullptr, Client.Socket); + ASSERT_NE(Client.GetLocalAddress().Ipv4.sin_port, (uint16_t)0); + + ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.ConnectEvent, 500)); + ASSERT_TRUE(CxPlatEventWaitWithTimeout(ListenerContext.AcceptEvent, 500)); + ASSERT_EQ(true, ListenerContext.Rejected); + ASSERT_EQ(nullptr, ListenerContext.Server); + + ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.DisconnectEvent, 500)); +} + TEST_P(DataPathTest, TcpDisconnect) { CxPlatDataPath Datapath(nullptr, &TcpRecvCallbacks);