diff --git a/src/core/binding.c b/src/core/binding.c index d8f77fdd5f..68d538f8fd 100644 --- a/src/core/binding.c +++ b/src/core/binding.c @@ -1781,7 +1781,7 @@ QuicBindingUnreachable( } _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void QuicBindingSend( _In_ QUIC_BINDING* Binding, _In_ const CXPLAT_ROUTE* Route, @@ -1790,8 +1790,6 @@ QuicBindingSend( _In_ uint32_t DatagramsToSend ) { - QUIC_STATUS Status; - #if QUIC_TEST_DATAPATH_HOOKS_ENABLED QUIC_TEST_DATAPATH_HOOKS* Hooks = MsQuicLib.TestDatapathHooks; if (Hooks != NULL) { @@ -1810,35 +1808,12 @@ QuicBindingSend( "[bind][%p] Test dropped packet", Binding); CxPlatSendDataFree(SendData); - Status = QUIC_STATUS_SUCCESS; } else { - Status = - CxPlatSocketSend( - Binding->Socket, - &RouteCopy, - SendData); - if (QUIC_FAILED(Status)) { - QuicTraceLogWarning( - BindingSendFailed, - "[bind][%p] Send failed, 0x%x", - Binding, - Status); - } + CxPlatSocketSend(Binding->Socket, &RouteCopy, SendData); } } else { #endif - Status = - CxPlatSocketSend( - Binding->Socket, - Route, - SendData); - if (QUIC_FAILED(Status)) { - QuicTraceLogWarning( - BindingSendFailed, - "[bind][%p] Send failed, 0x%x", - Binding, - Status); - } + CxPlatSocketSend(Binding->Socket, Route, SendData); #if QUIC_TEST_DATAPATH_HOOKS_ENABLED } #endif @@ -1846,6 +1821,4 @@ QuicBindingSend( QuicPerfCounterAdd(QUIC_PERF_COUNTER_UDP_SEND, DatagramsToSend); QuicPerfCounterAdd(QUIC_PERF_COUNTER_UDP_SEND_BYTES, BytesToSend); QuicPerfCounterIncrement(QUIC_PERF_COUNTER_UDP_SEND_CALLS); - - return Status; } diff --git a/src/core/binding.h b/src/core/binding.h index 1c5b94e42b..bdf61c3cb6 100644 --- a/src/core/binding.h +++ b/src/core/binding.h @@ -456,7 +456,7 @@ QuicBindingReleaseStatelessOperation( // the duration of the send operation. // _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void QuicBindingSend( _In_ QUIC_BINDING* Binding, _In_ const CXPLAT_ROUTE* Route, diff --git a/src/inc/quic_datapath.h b/src/inc/quic_datapath.h index 929becd130..d8dddc4036 100644 --- a/src/inc/quic_datapath.h +++ b/src/inc/quic_datapath.h @@ -722,7 +722,7 @@ CxPlatSendDataIsFull( // Sends the data over the socket. // _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void CxPlatSocketSend( _In_ CXPLAT_SOCKET* Socket, _In_ const CXPLAT_ROUTE* Route, diff --git a/src/perf/lib/Tcp.cpp b/src/perf/lib/Tcp.cpp index 62158a142e..a41a348e33 100644 --- a/src/perf/lib/Tcp.cpp +++ b/src/perf/lib/Tcp.cpp @@ -519,10 +519,11 @@ TcpConnection::ConnectCallback( Connected); if (Connected) { This->StartTls = true; - } else { + This->Queue(); + } else if (!This->Shutdown) { This->Shutdown = true; + This->Queue(); } - This->Queue(); } _IRQL_requires_max_(DISPATCH_LEVEL) @@ -566,22 +567,30 @@ void TcpConnection::SendCompleteCallback( _In_ CXPLAT_SOCKET* /* Socket */, _In_ void* Context, - _In_ QUIC_STATUS /* Status */, + _In_ QUIC_STATUS Status, _In_ uint32_t ByteCount ) { TcpConnection* This = (TcpConnection*)Context; + bool QueueWork = false; QuicTraceLogVerbose( PerfTcpSendCompleteCallback, - "[perf][tcp][%p] SendComplete callback", - This); + "[perf][tcp][%p] SendComplete callback, %u", + This, + (uint32_t)Status); CxPlatDispatchLockAcquire(&This->Lock); - if (This->TotalSendCompleteOffset != UINT64_MAX) { + if (QUIC_FAILED(Status)) { + if (!This->Shutdown) { + This->Shutdown = true; + QueueWork = true; + } + } else if (This->TotalSendCompleteOffset != UINT64_MAX) { This->TotalSendCompleteOffset += ByteCount; This->IndicateSendComplete = true; + QueueWork = true; } CxPlatDispatchLockRelease(&This->Lock); - if (This->TotalSendCompleteOffset != UINT64_MAX) { + if (QueueWork) { This->Queue(); } } @@ -656,10 +665,7 @@ void TcpConnection::Process() } } if (BatchedSendData && !Shutdown) { - if (QUIC_FAILED( - CxPlatSocketSend(Socket, &Route, BatchedSendData))) { - Shutdown = true; - } + CxPlatSocketSend(Socket, &Route, BatchedSendData); BatchedSendData = nullptr; } if (IndicateSendComplete) { @@ -754,7 +760,7 @@ bool TcpConnection::ProcessTls(const uint8_t* Buffer, uint32_t BufferLength) IndicateConnect = true; } - while (BaseOffset < TlsState.BufferTotalLength) { + while (!Shutdown && BaseOffset < TlsState.BufferTotalLength) { if (TlsState.BufferOffsetHandshake) { if (BaseOffset < TlsState.BufferOffsetHandshake) { uint16_t Length = (uint16_t)(TlsState.BufferOffsetHandshake - BaseOffset); @@ -808,7 +814,9 @@ bool TcpConnection::SendTlsData(const uint8_t* Buffer, uint16_t BufferLength, ui } SendBuffer->Length = sizeof(TcpFrame) + Frame->Length + CXPLAT_ENCRYPTION_OVERHEAD; - return FinalizeSendBuffer(SendBuffer); + FinalizeSendBuffer(SendBuffer); + + return true; } bool TcpConnection::ProcessReceive() @@ -1008,11 +1016,9 @@ bool TcpConnection::ProcessSend() } SendBuffer->Length = sizeof(TcpFrame) + Frame->Length + CXPLAT_ENCRYPTION_OVERHEAD; - if (!FinalizeSendBuffer(SendBuffer)) { - return false; - } + FinalizeSendBuffer(SendBuffer); - } while (NextSendData->Length > Offset); + } while (!Shutdown && NextSendData->Length > Offset); NextSendData->Offset = TotalSendOffset; NextSendData = NextSendData->Next; @@ -1088,18 +1094,14 @@ void TcpConnection::FreeSendBuffer(QUIC_BUFFER* SendBuffer) CxPlatSendDataFreeBuffer(BatchedSendData, SendBuffer); } -bool TcpConnection::FinalizeSendBuffer(QUIC_BUFFER* SendBuffer) +void TcpConnection::FinalizeSendBuffer(QUIC_BUFFER* SendBuffer) { TotalSendOffset += SendBuffer->Length; if (SendBuffer->Length != TLS_BLOCK_SIZE || CxPlatSendDataIsFull(BatchedSendData)) { - auto Status = CxPlatSocketSend(Socket, &Route, BatchedSendData); + CxPlatSocketSend(Socket, &Route, BatchedSendData); BatchedSendData = nullptr; - if (QUIC_FAILED(Status)) { - return false; - } } - return true; } bool TcpConnection::Send(TcpSendData* Data) diff --git a/src/perf/lib/Tcp.h b/src/perf/lib/Tcp.h index b6663f422c..25a7ba25cf 100644 --- a/src/perf/lib/Tcp.h +++ b/src/perf/lib/Tcp.h @@ -268,7 +268,7 @@ class TcpConnection { bool EncryptFrame(TcpFrame* Frame); QUIC_BUFFER* NewSendBuffer(); void FreeSendBuffer(QUIC_BUFFER* SendBuffer); - bool FinalizeSendBuffer(QUIC_BUFFER* SendBuffer); + void FinalizeSendBuffer(QUIC_BUFFER* SendBuffer); bool TryAddRef() { return CxPlatRefIncrementNonZero(&Ref, 1) != FALSE; } void Release() { if (CxPlatRefDecrement(&Ref)) delete this; } public: diff --git a/src/platform/datapath_epoll.c b/src/platform/datapath_epoll.c index f6cd826e7e..7bfbea36a2 100644 --- a/src/platform/datapath_epoll.c +++ b/src/platform/datapath_epoll.c @@ -2222,7 +2222,7 @@ CxPlatSendDataSend( _In_ CXPLAT_SEND_DATA* SendData ); -QUIC_STATUS +void SocketSend( _In_ CXPLAT_SOCKET* Socket, _In_ const CXPLAT_ROUTE* Route, @@ -2272,7 +2272,7 @@ SocketSend( &SocketContext->FlushTxSqe.Sqe, &SocketContext->FlushTxSqe)); } - return QUIC_STATUS_SUCCESS; + return; } // @@ -2299,8 +2299,6 @@ SocketSend( } CxPlatSendDataFree(SendData); } - - return Status; } // diff --git a/src/platform/datapath_kqueue.c b/src/platform/datapath_kqueue.c index ba70158ad5..464d122550 100644 --- a/src/platform/datapath_kqueue.c +++ b/src/platform/datapath_kqueue.c @@ -2089,7 +2089,7 @@ CxPlatSocketSendInternal( return Status; } -QUIC_STATUS +void CxPlatSocketSend( _In_ CXPLAT_SOCKET* Socket, _In_ const CXPLAT_ROUTE* Route, @@ -2098,18 +2098,12 @@ CxPlatSocketSend( { UNREFERENCED_PARAMETER(Socket); CXPLAT_DBG_ASSERT(Route->Queue); - CXPLAT_SOCKET_CONTEXT* SocketContext = Route->Queue; - QUIC_STATUS Status = - CxPlatSocketSendInternal( - SocketContext, - &Route->LocalAddress, - &Route->RemoteAddress, - SendData, - FALSE); - if (Status == QUIC_STATUS_PENDING) { - Status = QUIC_STATUS_SUCCESS; - } - return Status; + CxPlatSocketSendInternal( + Route->Queue, + &Route->LocalAddress, + &Route->RemoteAddress, + SendData, + FALSE); } uint16_t diff --git a/src/platform/datapath_winkernel.c b/src/platform/datapath_winkernel.c index d8aa7e9a3b..4b034737db 100644 --- a/src/platform/datapath_winkernel.c +++ b/src/platform/datapath_winkernel.c @@ -3002,7 +3002,7 @@ CxPlatSocketPrepareSendData( } _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void CxPlatSocketSend( _In_ CXPLAT_SOCKET* Binding, _In_ const CXPLAT_ROUTE* Route, @@ -3120,8 +3120,6 @@ CxPlatSocketSend( // Callback still gets invoked on failure to do the cleanup. // } - - return STATUS_SUCCESS; } _IRQL_requires_max_(DISPATCH_LEVEL) diff --git a/src/platform/datapath_winuser.c b/src/platform/datapath_winuser.c index cb1d9f5619..bae93f68b9 100644 --- a/src/platform/datapath_winuser.c +++ b/src/platform/datapath_winuser.c @@ -4013,7 +4013,7 @@ CxPlatSendDataComplete( } _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void CxPlatSocketSendWithRio( _In_ CXPLAT_SEND_DATA* SendData, _In_ WSAMSG* WSAMhdr @@ -4066,18 +4066,16 @@ CxPlatSocketSendWithRio( WsaError, "RIOSendEx"); SendDataFree(SendData); - return HRESULT_FROM_WIN32(WsaError); + return; } SocketProc->RioSendCount++; CxPlatSocketArmRioNotify(SocketProc); } - - return QUIC_STATUS_SUCCESS; } _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void CxPlatSocketSendInline( _In_ const QUIC_ADDR* LocalAddress, _In_ CXPLAT_SEND_DATA* SendData @@ -4086,10 +4084,9 @@ CxPlatSocketSendInline( CXPLAT_SOCKET_PROC* SocketProc = SendData->SocketProc; if (SocketProc->RioSendCount == RIO_SEND_QUEUE_DEPTH) { CxPlatListInsertTail(&SocketProc->RioSendOverflow, &SendData->RioOverflowEntry); - return QUIC_STATUS_PENDING; + return; } - QUIC_STATUS Status; int Result; DWORD BytesSent; CXPLAT_DATAPATH* Datapath = SocketProc->Parent->Datapath; @@ -4174,7 +4171,8 @@ CxPlatSocketSendInline( } if (Socket->Type == CXPLAT_SOCKET_UDP && Socket->UseRio) { - return CxPlatSocketSendWithRio(SendData, &WSAMhdr); + CxPlatSocketSendWithRio(SendData, &WSAMhdr); + return; } // @@ -4208,11 +4206,8 @@ CxPlatSocketSendInline( if (Result == SOCKET_ERROR) { WsaError = WSAGetLastError(); if (WsaError == WSA_IO_PENDING) { - return QUIC_STATUS_SUCCESS; + return; } - Status = HRESULT_FROM_WIN32(WsaError); - } else { - Status = QUIC_STATUS_SUCCESS; } // @@ -4220,11 +4215,9 @@ CxPlatSocketSendInline( // CxPlatCancelDatapathIo(SocketProc, &SendData->Sqe); CxPlatSendDataComplete(SendData, WsaError); - - return Status; } -QUIC_STATUS +void CxPlatSocketSendEnqueue( _In_ const CXPLAT_ROUTE* Route, _In_ CXPLAT_SEND_DATA* SendData @@ -4237,11 +4230,10 @@ CxPlatSocketSendEnqueue( if (QUIC_FAILED(Status)) { CxPlatCancelDatapathIo(SendData->SocketProc, &SendData->Sqe); } - return Status; } _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void SocketSend( _In_ CXPLAT_SOCKET* Socket, _In_ const CXPLAT_ROUTE* Route, @@ -4265,21 +4257,18 @@ SocketSend( // // Currently RIO always queues sends. // - return CxPlatSocketSendEnqueue(Route, SendData); - } + CxPlatSocketSendEnqueue(Route, SendData); - if ((Socket->Type != CXPLAT_SOCKET_UDP) || + } else if ((Socket->Type != CXPLAT_SOCKET_UDP) || !(SendData->SendFlags & CXPLAT_SEND_FLAGS_MAX_THROUGHPUT)) { // // Currently TCP always sends inline. // - return - CxPlatSocketSendInline( - &Route->LocalAddress, - SendData); - } + CxPlatSocketSendInline(&Route->LocalAddress, SendData); - return CxPlatSocketSendEnqueue(Route, SendData); + } else { + CxPlatSocketSendEnqueue(Route, SendData); + } } void diff --git a/src/platform/datapath_xplat.c b/src/platform/datapath_xplat.c index 6d8c16485a..865513fbac 100644 --- a/src/platform/datapath_xplat.c +++ b/src/platform/datapath_xplat.c @@ -350,18 +350,19 @@ CxPlatSendDataIsFull( } _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void CxPlatSocketSend( _In_ CXPLAT_SOCKET* Socket, _In_ const CXPLAT_ROUTE* Route, _In_ CXPLAT_SEND_DATA* SendData ) { - CXPLAT_DBG_ASSERT( - DatapathType(SendData) == CXPLAT_DATAPATH_TYPE_USER || - DatapathType(SendData) == CXPLAT_DATAPATH_TYPE_RAW); - return DatapathType(SendData) == CXPLAT_DATAPATH_TYPE_USER ? - SocketSend(Socket, Route, SendData) : RawSocketSend(CxPlatSocketToRaw(Socket), Route, SendData); + if (DatapathType(SendData) == CXPLAT_DATAPATH_TYPE_USER) { + SocketSend(Socket, Route, SendData); + } else { + CXPLAT_DBG_ASSERT(DatapathType(SendData) == CXPLAT_DATAPATH_TYPE_RAW); + RawSocketSend(CxPlatSocketToRaw(Socket), Route, SendData); + } } _IRQL_requires_max_(PASSIVE_LEVEL) diff --git a/src/platform/pcp.c b/src/platform/pcp.c index 41307eed33..b91e6a54f6 100644 --- a/src/platform/pcp.c +++ b/src/platform/pcp.c @@ -424,14 +424,7 @@ CxPlatPcpSendMapRequestInternal( Request->MAP.SuggestedExternalIpAddress, sizeof(Request->MAP.SuggestedExternalIpAddress)); - QUIC_STATUS Status = - CxPlatSocketSend( - Socket, - &Route, - SendData); - if (QUIC_FAILED(Status)) { - return Status; - } + CxPlatSocketSend(Socket, &Route, SendData); return QUIC_STATUS_SUCCESS; } @@ -528,14 +521,7 @@ CxPlatPcpSendPeerRequestInternal( sizeof(Request->PEER.RemotePeerIpAddress)); Request->PEER.RemotePeerPort = RemotePeerMappedAddress.Ipv6.sin6_port; - QUIC_STATUS Status = - CxPlatSocketSend( - Socket, - &Route, - SendData); - if (QUIC_FAILED(Status)) { - return Status; - } + CxPlatSocketSend(Socket, &Route, SendData); return QUIC_STATUS_SUCCESS; } diff --git a/src/platform/platform_internal.h b/src/platform/platform_internal.h index 6e4675eb14..cbc71bf492 100644 --- a/src/platform/platform_internal.h +++ b/src/platform/platform_internal.h @@ -1076,7 +1076,7 @@ SendDataIsFull( ); _IRQL_requires_max_(DISPATCH_LEVEL) -QUIC_STATUS +void SocketSend( _In_ CXPLAT_SOCKET* Socket, _In_ const CXPLAT_ROUTE* Route, diff --git a/src/platform/unittest/DataPathTest.cpp b/src/platform/unittest/DataPathTest.cpp index 911d0f3fb4..e030c0ceb5 100644 --- a/src/platform/unittest/DataPathTest.cpp +++ b/src/platform/unittest/DataPathTest.cpp @@ -300,11 +300,7 @@ struct DataPathTest : public ::testing::TestWithParam ASSERT_NE(nullptr, ServerBuffer); memcpy(ServerBuffer->Buffer, RecvData->Buffer, RecvData->BufferLength); - VERIFY_QUIC_SUCCESS( - CxPlatSocketSend( - Socket, - RecvData->Route, - ServerSendData)); + CxPlatSocketSend(Socket, RecvData->Route, ServerSendData); } else if (RecvData->Route->RemoteAddress.Ipv4.sin_port == RecvContext->DestinationAddress.Ipv4.sin_port) { CxPlatEventSet(RecvContext->ClientCompletion); @@ -606,19 +602,15 @@ struct CxPlatSocket { QUIC_ADDR GetRemoteAddress() const noexcept { return Route.RemoteAddress; } - QUIC_STATUS + void Send( _In_ const CXPLAT_ROUTE& _Route, _In_ CXPLAT_SEND_DATA* SendData ) const noexcept { - return - CxPlatSocketSend( - Socket, - &_Route, - SendData); + CxPlatSocketSend(Socket, &_Route, SendData); } - QUIC_STATUS + void Send( _In_ const QUIC_ADDR& RemoteAddress, _In_ CXPLAT_SEND_DATA* SendData @@ -626,14 +618,14 @@ struct CxPlatSocket { { CXPLAT_ROUTE _Route = Route; _Route.RemoteAddress = RemoteAddress; - return Send(_Route, SendData); + Send(_Route, SendData); } - QUIC_STATUS + void Send( _In_ CXPLAT_SEND_DATA* SendData ) const noexcept { - return Send(Route, SendData); + Send(Route, SendData); } }; @@ -804,7 +796,7 @@ TEST_P(DataPathTest, UdpData) ASSERT_NE(nullptr, ClientBuffer); memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS(Client.Send(ClientSendData)); + Client.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); } @@ -841,7 +833,7 @@ TEST_P(DataPathTest, UdpDataPolling) ASSERT_NE(nullptr, ClientBuffer); memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS(Client.Send(ClientSendData)); + Client.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); } @@ -878,7 +870,7 @@ TEST_P(DataPathTest, UdpDataRebind) ASSERT_NE(nullptr, ClientBuffer); memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS(Client.Send(ClientSendData)); + Client.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); CxPlatEventReset(RecvContext.ClientCompletion); } @@ -895,7 +887,7 @@ TEST_P(DataPathTest, UdpDataRebind) ASSERT_NE(nullptr, ClientBuffer); memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS(Client.Send(ClientSendData)); + Client.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); } } @@ -933,7 +925,7 @@ TEST_P(DataPathTest, UdpDataECT0) ASSERT_NE(nullptr, ClientBuffer); memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS(Client.Send(ClientSendData)); + Client.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); } @@ -982,7 +974,7 @@ TEST_P(DataPathTest, UdpShareClientSocket) memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); RecvContext.DestinationAddress = Server1.GetLocalAddress(); - VERIFY_QUIC_SUCCESS(Client1.Send(ClientSendData)); + Client1.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); CxPlatEventReset(RecvContext.ClientCompletion); @@ -994,7 +986,7 @@ TEST_P(DataPathTest, UdpShareClientSocket) memcpy(ClientBuffer->Buffer, ExpectedData, ExpectedDataSize); RecvContext.DestinationAddress = Server2.GetLocalAddress(); - VERIFY_QUIC_SUCCESS(Client2.Send(ClientSendData)); + Client2.Send(ClientSendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(RecvContext.ClientCompletion, 2000)); CxPlatEventReset(RecvContext.ClientCompletion); } @@ -1145,7 +1137,7 @@ TEST_P(DataPathTest, TcpDataClient) ASSERT_NE(nullptr, SendBuffer); memcpy(SendBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS(Client.Send(SendData)); + Client.Send(SendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(ListenerContext.ServerContext.ReceiveEvent, 500)); } @@ -1190,11 +1182,7 @@ TEST_P(DataPathTest, TcpDataServer) ASSERT_NE(nullptr, SendBuffer); memcpy(SendBuffer->Buffer, ExpectedData, ExpectedDataSize); - VERIFY_QUIC_SUCCESS( - CxPlatSocketSend( - ListenerContext.Server, - &Route, - SendData)); + CxPlatSocketSend(ListenerContext.Server, &Route, SendData); ASSERT_TRUE(CxPlatEventWaitWithTimeout(ClientContext.ReceiveEvent, 500)); } diff --git a/src/test/lib/QuicDrill.cpp b/src/test/lib/QuicDrill.cpp index fea062a0b9..ab8fdaf867 100644 --- a/src/test/lib/QuicDrill.cpp +++ b/src/test/lib/QuicDrill.cpp @@ -191,7 +191,6 @@ struct DrillSender { _In_ const DrillBuffer& PacketBuffer ) { - QUIC_STATUS Status = QUIC_STATUS_SUCCESS; CXPLAT_FRE_ASSERT(PacketBuffer.size() <= UINT16_MAX); const uint16_t DatagramLength = (uint16_t) PacketBuffer.size(); @@ -208,8 +207,7 @@ struct DrillSender { if (SendBuffer == nullptr) { TEST_FAILURE("Buffer null"); - Status = QUIC_STATUS_OUT_OF_MEMORY; - return Status; + return QUIC_STATUS_OUT_OF_MEMORY; } // @@ -217,13 +215,12 @@ struct DrillSender { // memcpy(SendBuffer->Buffer, PacketBuffer.data(), DatagramLength); - Status = - CxPlatSocketSend( - Binding, - &Route, - SendData); + CxPlatSocketSend( + Binding, + &Route, + SendData); - return Status; + return QUIC_STATUS_SUCCESS; } }; diff --git a/src/tools/recvfuzz/recvfuzz.cpp b/src/tools/recvfuzz/recvfuzz.cpp index b31ab3179a..899c018946 100644 --- a/src/tools/recvfuzz/recvfuzz.cpp +++ b/src/tools/recvfuzz/recvfuzz.cpp @@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation. Licensed under the MIT License. -Abstract: +Abstract: Packet Fuzzer tool in the receive path. --*/ @@ -124,7 +124,7 @@ T GetRandom(T UpperBound) { _IRQL_requires_max_(DISPATCH_LEVEL) _Function_class_(CXPLAT_DATAPATH_RECEIVE_CALLBACK) void -UdpRecvCallback( +UdpRecvCallback( _In_ CXPLAT_SOCKET* Binding, _In_ void* Context, _In_ CXPLAT_RECV_DATA* RecvBufferChain @@ -171,11 +171,11 @@ UdpRecvCallback( Packet.KeyType = QuicPacketTypeToKeyTypeV1(Packet.LH->Type); } Packet.Encrypted = TRUE; - if (Packet.AvailBufferLength >= Packet.HeaderLength && - (memcmp(Packet.DestCid, &CurrSrcCid, sizeof(uint64_t)) == 0) && + if (Packet.AvailBufferLength >= Packet.HeaderLength && + (memcmp(Packet.DestCid, &CurrSrcCid, sizeof(uint64_t)) == 0) && (Packet.LH->Type == QUIC_INITIAL_V1 || Packet.LH->Type == QUIC_HANDSHAKE_V1)) { Packet.AvailBufferLength = Packet.HeaderLength + Packet.PayloadLength; - QUIC_RX_PACKET* PacketCopy = (QUIC_RX_PACKET *)CXPLAT_ALLOC_NONPAGED(sizeof(QUIC_RX_PACKET) + Packet.AvailBufferLength + Packet.DestCidLen + Packet.SourceCidLen, QUIC_POOL_TOOL); + QUIC_RX_PACKET* PacketCopy = (QUIC_RX_PACKET *)CXPLAT_ALLOC_NONPAGED(sizeof(QUIC_RX_PACKET) + Packet.AvailBufferLength + Packet.DestCidLen + Packet.SourceCidLen, QUIC_POOL_TOOL); memcpy(PacketCopy, &Packet, sizeof(QUIC_RX_PACKET)); PacketCopy->AvailBuffer = (uint8_t*)(PacketCopy + 1); memcpy((void *)PacketCopy->AvailBuffer, Packet.AvailBuffer, Packet.AvailBufferLength); @@ -184,7 +184,7 @@ UdpRecvCallback( PacketCopy->SourceCid = PacketCopy->DestCid + Packet.DestCidLen; memcpy((void *)PacketCopy->SourceCid, Packet.SourceCid, Packet.SourceCidLen); PacketQueue.push_back(PacketCopy); - } + } Packet.AvailBuffer += Packet.AvailBufferLength; } while (Packet.AvailBuffer - Datagram->Buffer < Datagram->BufferLength); Datagram = Datagram->Next; @@ -219,7 +219,7 @@ struct TlsContext State.Buffer = (uint8_t*)CXPLAT_ALLOC_NONPAGED(8000, QUIC_POOL_TOOL); State.BufferAllocLength = 8000; } - void CreateContext(uint64_t initSrcCid = MagicCid) { + void CreateContext(uint64_t initSrcCid = MagicCid) { uint8_t *stateBuffer = State.Buffer; CxPlatZeroMemory(&State, sizeof(State)); State.Buffer = stateBuffer; @@ -233,7 +233,7 @@ struct TlsContext OnRecvQuicTP, NULL }; - + if (QUIC_FAILED( CxPlatTlsSecConfigCreate( &CredConfig, @@ -417,7 +417,7 @@ bool WriteAckFrame( _Inout_ uint16_t* Offset, _In_ uint16_t BufferLength, _Out_writes_to_(BufferLength, *Offset) uint8_t* Buffer - ) + ) { QUIC_RANGE AckRange; QuicRangeInitialize(QUIC_MAX_RANGE_DECODE_ACKS, &AckRange); @@ -425,11 +425,11 @@ bool WriteAckFrame( QuicRangeAddRange(&AckRange, LargestAcknowledge, 1, &RangeUpdated); uint64_t AckDelay = 40; if (!QuicAckFrameEncode( - &AckRange, - AckDelay, - nullptr, - Offset, - BufferLength, + &AckRange, + AckDelay, + nullptr, + Offset, + BufferLength, Buffer)) { printf("QuicAckFrameEncode failure!\n"); return false; @@ -437,11 +437,11 @@ bool WriteAckFrame( return true; } -bool WriteCryptoFrame( +bool WriteCryptoFrame( _Inout_ uint16_t* Offset, _In_ uint16_t BufferLength, _Out_writes_to_(BufferLength, *Offset) - uint8_t* Buffer, + uint8_t* Buffer, _In_ TlsContext* ClientContext, _In_ PacketParams* PacketParams ) @@ -491,7 +491,7 @@ bool WriteCryptoFrame( return true; } -bool WriteClientPacket( +bool WriteClientPacket( _In_ uint32_t PacketNumber, _In_ uint16_t BufferLength, _Out_writes_to_(BufferLength, *PacketLength) @@ -509,9 +509,9 @@ bool WriteClientPacket( for (int i = 0; i < PacketParams->NumFrames; i++) { if (PacketParams->FrameTypes[i] == QUIC_FRAME_ACK) { if (!WriteAckFrame( - PacketParams->LargestAcknowledge, - &FrameBufferLength, - BufferSize, + PacketParams->LargestAcknowledge, + &FrameBufferLength, + BufferSize, FrameBuffer)) { return false; } @@ -519,10 +519,10 @@ bool WriteClientPacket( if (PacketParams->FrameTypes[i] == QUIC_FRAME_CRYPTO) { if (!WriteCryptoFrame( - &FrameBufferLength, - BufferSize, - FrameBuffer, - ClientContext, + &FrameBufferLength, + BufferSize, + FrameBuffer, + ClientContext, PacketParams)) { return false; } @@ -533,7 +533,7 @@ bool WriteClientPacket( QUIC_CID* DestCid = (QUIC_CID*)DestCidBuffer; QUIC_CID* SourceCid = (QUIC_CID*)SourceCidBuffer; - + DestCid->IsInitial = TRUE; DestCid->Length = PacketParams->DestCidLen; @@ -577,19 +577,19 @@ bool WriteClientPacket( void fuzzPacket(uint8_t* Packet, uint16_t PacketLength) { uint8_t numIteration = (uint8_t)GetRandom(256); for(int i = 0; i < numIteration; i++){ - Packet[GetRandom(PacketLength)] = (uint8_t)GetRandom(256); + Packet[GetRandom(PacketLength)] = (uint8_t)GetRandom(256); } } void sendPacket( - CXPLAT_SOCKET* Binding, - CXPLAT_ROUTE Route, - int64_t* PacketCount, - int64_t* TotalByteCount, - PacketParams* PacketParams, - bool fuzzing = true, + CXPLAT_SOCKET* Binding, + CXPLAT_ROUTE Route, + int64_t* PacketCount, + int64_t* TotalByteCount, + PacketParams* PacketParams, + bool fuzzing = true, TlsContext* ClientContext = nullptr) { - const uint16_t DatagramLength = QUIC_MIN_INITIAL_LENGTH; + const uint16_t DatagramLength = QUIC_MIN_INITIAL_LENGTH; CXPLAT_SEND_CONFIG SendConfig = { &Route, DatagramLength, CXPLAT_ECN_NON_ECT, 0 }; CXPLAT_SEND_DATA* SendData = CxPlatSendDataAlloc(Binding, &SendConfig); if (!SendData) { @@ -612,7 +612,7 @@ void sendPacket( CxPlatSendDataFree(SendData); return; } - + uint16_t PacketNumberOffset = HeaderLength - sizeof(uint32_t); uint8_t* DestCid = (uint8_t*)(Packet + sizeof(QUIC_LONG_HEADER_V1)); @@ -622,9 +622,9 @@ void sendPacket( } else { memcpy(DestCid, PacketParams->DestCid, PacketParams->DestCidLen); } - + memcpy(SrcCid, PacketParams->SourceCid, PacketParams->SourceCidLen); - + if (fuzzing) { fuzzPacket(Packet, sizeof(Packet)); } @@ -703,15 +703,8 @@ void sendPacket( break; } } - - if (QUIC_FAILED( - CxPlatSocketSend( - Binding, - &Route, - SendData))) { - printf("Send failed!\n"); - exit(0); - } + + CxPlatSocketSend(Binding, &Route, SendData); } void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { @@ -765,11 +758,11 @@ void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { while (!PacketQueue.empty()) { QUIC_RX_PACKET* packet = PacketQueue.front(); - if (!packet->DestCidLen || - !packet->DestCid || packet->PayloadLength < 4 + CXPLAT_HP_SAMPLE_LENGTH || + if (!packet->DestCidLen || + !packet->DestCid || packet->PayloadLength < 4 + CXPLAT_HP_SAMPLE_LENGTH || (memcmp(packet->DestCid, &CurrSrcCid, sizeof(uint64_t)) != 0)) { CXPLAT_FREE(packet, QUIC_POOL_TOOL); - PacketQueue.pop_front(); + PacketQueue.pop_front(); continue; } if (packet->LH->Type == QUIC_INITIAL_V1) { @@ -782,11 +775,11 @@ void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { packet->AvailBuffer + packet->HeaderLength + 4, CXPLAT_HP_SAMPLE_LENGTH); // same step for all long header packets - - QUIC_PACKET_KEY_TYPE KeyType = packet->KeyType; + + QUIC_PACKET_KEY_TYPE KeyType = packet->KeyType; if (HandshakeClientContext.State.ReadKeys[KeyType] == nullptr) { CXPLAT_FREE(packet, QUIC_POOL_TOOL); - PacketQueue.pop_front(); + PacketQueue.pop_front(); continue; } if (QUIC_FAILED( @@ -798,7 +791,7 @@ void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { printf("Failed to Compute Mask\n"); } uint8_t CompressedPacketNumberLength = 0; - ((uint8_t*)packet->AvailBuffer)[0] ^= HpMask[0] & 0x0F; + ((uint8_t*)packet->AvailBuffer)[0] ^= HpMask[0] & 0x0F; CompressedPacketNumberLength = packet->LH->PnLength + 1; for (uint8_t i = 0; i < CompressedPacketNumberLength; i++) { ((uint8_t*)packet->AvailBuffer)[packet->HeaderLength + i] ^= HpMask[1 + i]; @@ -834,11 +827,11 @@ void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { (uint8_t*)Payload))) { // Buffer printf("CxPlatDecrypt failed\n"); CXPLAT_FREE(packet, QUIC_POOL_TOOL); - PacketQueue.pop_front(); + PacketQueue.pop_front(); continue; } packet->PayloadLength -= CXPLAT_ENCRYPTION_OVERHEAD; - + QUIC_VAR_INT FrameType INIT_NO_SAL(0); uint16_t offset = 0; uint16_t PayloadLength = packet->PayloadLength; @@ -904,12 +897,12 @@ void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { sendPacket(Binding, Route, &HandshakePacketCount, &TotalByteCount, &HandshakePacketParams, true, &HandshakeClientContext); handshakeComplete = FALSE; CXPLAT_FREE(packet, QUIC_POOL_TOOL); - PacketQueue.pop_front(); + PacketQueue.pop_front(); break; } CXPLAT_FREE(packet, QUIC_POOL_TOOL); - PacketQueue.pop_front(); - } + PacketQueue.pop_front(); + } for (uint8_t i = 0; i < QUIC_PACKET_KEY_COUNT; ++i) { if (HandshakeClientContext.State.ReadKeys[i] != nullptr) { @@ -920,12 +913,12 @@ void fuzz(CXPLAT_SOCKET* Binding, CXPLAT_ROUTE Route) { QuicPacketKeyFree(HandshakeClientContext.State.WriteKeys[i]); HandshakeClientContext.State.WriteKeys[i] = nullptr; } - } + } } } while (!PacketQueue.empty()) { QUIC_RX_PACKET* packet = PacketQueue.front(); - CXPLAT_FREE(packet, QUIC_POOL_TOOL); + CXPLAT_FREE(packet, QUIC_POOL_TOOL); PacketQueue.pop_front(); } printf("Total Initial Packets sent: %lld\n", (long long)InitialPacketCount); @@ -1018,7 +1011,7 @@ main(int argc, char **argv) { uint32_t RngSeed = 0; if (!TryGetValue(argc, argv, "seed", &RngSeed)) { CxPlatRandom(sizeof(RngSeed), &RngSeed); - } + } printf("Using seed value: %u\n", RngSeed); srand(RngSeed); start();