Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make IncomingCommand a class and simplify code around it #1628

Merged
merged 2 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions projects/RabbitMQ.Client/client/RentedMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ namespace RabbitMQ.Client
{
internal struct RentedMemory : IDisposable
{
private bool _disposedValue;

internal RentedMemory(byte[] rentedArray)
: this(new ReadOnlyMemory<byte>(rentedArray), rentedArray)
{
Expand All @@ -47,21 +45,20 @@ internal RentedMemory(ReadOnlyMemory<byte> memory, byte[] rentedArray)
{
Memory = memory;
RentedArray = rentedArray;
_disposedValue = false;
}

internal readonly ReadOnlyMemory<byte> Memory;

internal readonly byte[] ToArray()
{
return Memory.ToArray();
}
internal ReadOnlyMemory<byte> Memory;

internal readonly int Size => Memory.Length;

internal readonly ReadOnlySpan<byte> Span => Memory.Span;

internal readonly byte[] RentedArray;
internal byte[] RentedArray;

internal readonly byte[] ToArray()
{
return Memory.ToArray();
}

internal readonly ReadOnlyMemory<byte> CopyToMemory()
{
Expand All @@ -70,14 +67,11 @@ internal readonly ReadOnlyMemory<byte> CopyToMemory()

public void Dispose()
{
if (!_disposedValue)
if (RentedArray != null)
{
if (RentedArray != null)
{
ArrayPool<byte>.Shared.Return(RentedArray);
}

_disposedValue = true;
ArrayPool<byte>.Shared.Return(RentedArray);
RentedArray = default;
Memory = default;
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions projects/RabbitMQ.Client/client/framing/Channel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,17 @@ protected override Task<bool> DispatchCommandAsync(IncomingCommand cmd, Cancella
}
case ProtocolCommandId.BasicAck:
{
HandleBasicAck(in cmd);
HandleBasicAck(cmd);
return Task.FromResult(true);
}
case ProtocolCommandId.BasicNack:
{
HandleBasicNack(in cmd);
HandleBasicNack(cmd);
return Task.FromResult(true);
}
case ProtocolCommandId.BasicReturn:
{
HandleBasicReturn(in cmd);
HandleBasicReturn(cmd);
return Task.FromResult(true);
}
case ProtocolCommandId.ChannelClose:
Expand All @@ -118,7 +118,7 @@ protected override Task<bool> DispatchCommandAsync(IncomingCommand cmd, Cancella
}
case ProtocolCommandId.ConnectionBlocked:
{
HandleConnectionBlocked(in cmd);
HandleConnectionBlocked(cmd);
return Task.FromResult(true);
}
case ProtocolCommandId.ConnectionClose:
Expand All @@ -143,7 +143,7 @@ protected override Task<bool> DispatchCommandAsync(IncomingCommand cmd, Cancella
}
case ProtocolCommandId.ConnectionUnblocked:
{
HandleConnectionUnblocked(in cmd);
HandleConnectionUnblocked(cmd);
return Task.FromResult(true);
}
default:
Expand Down
30 changes: 15 additions & 15 deletions projects/RabbitMQ.Client/client/impl/ChannelBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ internal abstract class ChannelBase : IChannel, IRecoverable
internal TaskCompletionSource<ConnectionStartDetails> m_connectionStartCell;
private Exception m_connectionStartException = null;

// AMQP only allows one RPC operation to be active at a time.
// AMQP only allows one RPC operation to be active at a time.
protected readonly SemaphoreSlim _rpcSemaphore = new SemaphoreSlim(1, 1);
private readonly RpcContinuationQueue _continuationQueue = new RpcContinuationQueue();
private readonly ManualResetEventSlim _flowControlBlock = new ManualResetEventSlim(true);
Expand Down Expand Up @@ -425,8 +425,8 @@ private async Task HandleCommandAsync(IncomingCommand cmd, CancellationToken can
/*
* If DispatchCommandAsync returns `true`, it means that the incoming command is server-originated, and has
* already been handled.
*
* Else, the incoming command is the return of an RPC call, and must be handled.
*
* Else, the incoming command is the return of an RPC call, and must be handled.
*/
if (false == await DispatchCommandAsync(cmd, cancellationToken)
.ConfigureAwait(false))
Expand Down Expand Up @@ -561,7 +561,7 @@ public Task ConnectionTuneOkAsync(ushort channelMax, uint frameMax, ushort heart
return ModelSendAsync(method, cancellationToken).AsTask();
}

protected void HandleBasicAck(in IncomingCommand cmd)
protected void HandleBasicAck(IncomingCommand cmd)
{
try
{
Expand All @@ -580,7 +580,7 @@ protected void HandleBasicAck(in IncomingCommand cmd)
}
}

protected void HandleBasicNack(in IncomingCommand cmd)
protected void HandleBasicNack(IncomingCommand cmd)
{
try
{
Expand Down Expand Up @@ -679,17 +679,17 @@ await ConsumerDispatcher.HandleBasicDeliverAsync(
method._exchange,
method._routingKey,
header,
cmd.Body,
/*
* Takeover Body so it doesn't get returned as it is necessary
* for handling the Basic.Deliver method by client code.
*/
cmd.TakeoverBody(),
cancellationToken).ConfigureAwait(false);
return true;
}
finally
{
/*
* Note: do not return the Body as it is necessary for handling
* the Basic.Deliver method by client code
*/
cmd.ReturnMethodAndHeaderBuffers();
cmd.ReturnBuffers();
}
}

Expand All @@ -698,7 +698,7 @@ protected virtual ulong AdjustDeliveryTag(ulong deliveryTag)
return deliveryTag;
}

protected void HandleBasicReturn(in IncomingCommand cmd)
protected void HandleBasicReturn(IncomingCommand cmd)
{
try
{
Expand Down Expand Up @@ -800,7 +800,7 @@ await ModelSendAsync(method, cancellationToken).
}
}

protected void HandleConnectionBlocked(in IncomingCommand cmd)
protected void HandleConnectionBlocked(IncomingCommand cmd)
{
try
{
Expand Down Expand Up @@ -851,7 +851,7 @@ await ModelSendAsync(replyMethod, cancellationToken)
protected async Task<bool> HandleConnectionSecureAsync(IncomingCommand _)
{
var k = (ConnectionSecureOrTuneAsyncRpcContinuation)_continuationQueue.Next();
await k.HandleCommandAsync(IncomingCommand.Empty)
await k.HandleCommandAsync(new IncomingCommand())
.ConfigureAwait(false); // release the continuation.
return true;
}
Expand Down Expand Up @@ -903,7 +903,7 @@ await k.HandleCommandAsync(cmd)
return true;
}

protected void HandleConnectionUnblocked(in IncomingCommand cmd)
protected void HandleConnectionUnblocked(IncomingCommand cmd)
{
try
{
Expand Down
52 changes: 19 additions & 33 deletions projects/RabbitMQ.Client/client/impl/CommandAssembler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,19 @@ internal sealed class CommandAssembler
{
private const int MaxArrayOfBytesSize = 2_147_483_591;

private ProtocolCommandId _commandId;
private RentedMemory _methodMemory;
private RentedMemory _headerMemory;
private RentedMemory _bodyMemory;
private readonly IncomingCommand _currentCommand;
private readonly uint _maxBodyLength;

private int _remainingBodyByteCount;
private int _offset;
private AssemblyState _state;

private readonly uint _maxBodyLength;

public CommandAssembler(uint maxBodyLength)
{
_currentCommand = new IncomingCommand();
_maxBodyLength = maxBodyLength;
Reset();
}

private void Reset()
{
_commandId = default;
_methodMemory = default;
_headerMemory = default;
_bodyMemory = default;
_remainingBodyByteCount = 0;
_offset = 0;
_state = AssemblyState.ExpectingMethod;
}

public void HandleFrame(InboundFrame frame, out IncomingCommand command)
public IncomingCommand? HandleFrame(InboundFrame frame)
{
switch (_state)
{
Expand All @@ -88,13 +73,14 @@ public void HandleFrame(InboundFrame frame, out IncomingCommand command)

if (_state != AssemblyState.Complete)
{
command = IncomingCommand.Empty;
return;
return default;
}

RabbitMqClientEventSource.Log.CommandReceived();
command = new IncomingCommand(_commandId, _methodMemory, _headerMemory, _bodyMemory);
Reset();
_remainingBodyByteCount = 0;
_state = AssemblyState.ExpectingMethod;

return _currentCommand;
}

private void ParseMethodFrame(InboundFrame frame)
Expand All @@ -104,10 +90,10 @@ private void ParseMethodFrame(InboundFrame frame)
throw new UnexpectedFrameException(frame.Type);
}

_commandId = (ProtocolCommandId)NetworkOrderDeserializer.ReadUInt32(frame.Payload.Span);
_methodMemory = frame.TakeoverPayload(Framing.Method.ArgumentsOffset);
_currentCommand.CommandId = (ProtocolCommandId)NetworkOrderDeserializer.ReadUInt32(frame.Payload.Span);
_currentCommand.Method = frame.TakeoverPayload(Framing.Method.ArgumentsOffset);

switch (_commandId)
switch (_currentCommand.CommandId)
{
// Commands with payload
case ProtocolCommandId.BasicGetOk:
Expand Down Expand Up @@ -154,7 +140,7 @@ private void ParseHeaderFrame(InboundFrame frame)
}
else
{
_headerMemory = frame.TakeoverPayload(Framing.Header.HeaderArgumentOffset);
_currentCommand.Header = frame.TakeoverPayload(Framing.Header.HeaderArgumentOffset);
}

_remainingBodyByteCount = (int)totalBodyBytes;
Expand All @@ -174,25 +160,25 @@ private void ParseBodyFrame(InboundFrame frame)
throw new MalformedFrameException($"Overlong content body received - {_remainingBodyByteCount} bytes remaining, {payloadLength} bytes received");
}

if (_bodyMemory.RentedArray is null)
if (_currentCommand.Body.RentedArray is null)
{
// check for single frame payload for an early exit
if (payloadLength == _remainingBodyByteCount)
{
_bodyMemory = frame.TakeoverPayload(0);
_currentCommand.Body = frame.TakeoverPayload(0);
_state = AssemblyState.Complete;
return;
}

// Is returned by IncomingCommand.ReturnPayload in Session.HandleFrame
var rentedBodyArray = ArrayPool<byte>.Shared.Rent(_remainingBodyByteCount);
_bodyMemory = new RentedMemory(new ReadOnlyMemory<byte>(rentedBodyArray, 0, _remainingBodyByteCount), rentedBodyArray);
_currentCommand.Body.RentedArray = rentedBodyArray;
_currentCommand.Body.Memory = new ReadOnlyMemory<byte>(rentedBodyArray, 0, _remainingBodyByteCount);
}

frame.Payload.Span.CopyTo(_bodyMemory.RentedArray.AsSpan(_offset));
frame.Payload.Span.CopyTo(_currentCommand.Body.RentedArray.AsSpan(_currentCommand.Body.Memory.Length - _remainingBodyByteCount));
frame.TryReturnPayload();
_remainingBodyByteCount -= payloadLength;
_offset += payloadLength;
UpdateContentBodyState();
}

Expand Down
57 changes: 13 additions & 44 deletions projects/RabbitMQ.Client/client/impl/IncomingCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,60 +34,29 @@

namespace RabbitMQ.Client.Impl
{
internal readonly struct IncomingCommand
internal sealed class IncomingCommand
{
public static readonly IncomingCommand Empty = default;
public ProtocolCommandId CommandId;

public readonly ProtocolCommandId CommandId;
public RentedMemory Method;
public RentedMemory Header;
public RentedMemory Body;

public readonly RentedMemory Method;
public readonly RentedMemory Header;
public readonly RentedMemory Body;
public ReadOnlySpan<byte> MethodSpan => Method.Memory.Span;
public ReadOnlySpan<byte> HeaderSpan => Header.Memory.Span;
public ReadOnlySpan<byte> BodySpan => Body.Memory.Span;

public readonly bool IsEmpty => CommandId is default(ProtocolCommandId);

public IncomingCommand(ProtocolCommandId commandId,
RentedMemory method, RentedMemory header, RentedMemory body)
{
CommandId = commandId;
Method = method;
Header = header;
Body = body;
}

public ReadOnlySpan<byte> MethodSpan
public RentedMemory TakeoverBody()
{
get
{
return Method.Memory.Span;
}
RentedMemory body = Body;
Body = default;
return body;
}

public ReadOnlySpan<byte> HeaderSpan
{
get
{
return Header.Memory.Span;
}
}

public ReadOnlySpan<byte> BodySpan
{
get
{
return Body.Memory.Span;
}
}

public void ReturnMethodAndHeaderBuffers()
public void ReturnBuffers()
{
Method.Dispose();
Header.Dispose();
}

public void ReturnBuffers()
{
ReturnMethodAndHeaderBuffers();
Body.Dispose();
}
}
Expand Down
5 changes: 2 additions & 3 deletions projects/RabbitMQ.Client/client/impl/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ public Session(Connection connection, ushort channelNumber, uint maxBodyLength)

public override Task HandleFrameAsync(InboundFrame frame, CancellationToken cancellationToken)
{
_assembler.HandleFrame(frame, out IncomingCommand cmd);

if (cmd.IsEmpty)
IncomingCommand cmd = _assembler.HandleFrame(frame);
if (cmd is null)
{
return Task.CompletedTask;
}
Expand Down