diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs
index 9ce4f53d0..a87bf047b 100644
--- a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs
+++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs
@@ -138,23 +138,28 @@ public override int Encrypt(byte[] input, int offset, int length, byte[] output,
/// The decrypted plaintext.
public override byte[] Decrypt(byte[] input, int offset, int length)
{
- byte[] output;
+ var output = new byte[length];
+
+ _cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv));
+
+ var keyStream = new byte[64];
+ _cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0);
+ _mac.Init(new KeyParameter(keyStream, 0, 32));
if (_aadLength > 0)
{
// If we are in 'AAD mode', then put these bytes through the AAD cipher.
+ _mac.BlockUpdate(input, offset, length);
+
Debug.Assert(_aadCipher != null);
_aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv));
- output = new byte[length];
_aadCipher.ProcessBytes(input, offset, length, output, 0);
}
else
{
- output = new byte[length];
-
var bytesWritten = Decrypt(input, offset, length, output, 0);
Debug.Assert(bytesWritten == length);
@@ -169,7 +174,7 @@ public override byte[] Decrypt(byte[] input, int offset, int length)
///
/// The input data with below format:
///
- /// [----][----Cipher AAD----(offset)][----Cipher Text----(length)][----TAG----]
+ /// [----(offset)][----Cipher Text----(length)][----TAG----]
///
///
/// The zero-based offset in at which to begin decrypting and authenticating.
@@ -179,16 +184,8 @@ public override byte[] Decrypt(byte[] input, int offset, int length)
/// The number of plaintext bytes written to .
public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset)
{
- Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length");
-
- _cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv));
-
- var keyStream = new byte[64];
- _cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0);
- _mac.Init(new KeyParameter(keyStream, 0, 32));
-
var tag = new byte[TagSize];
- _mac.BlockUpdate(input, offset - _aadLength, length + _aadLength);
+ _mac.BlockUpdate(input, offset, length);
_ = _mac.DoFinal(tag, 0);
if (!Arrays.FixedTimeEquals(TagSize, tag, 0, input, offset + length))
{
diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs
index a6576d9af..c73c3db26 100644
--- a/src/Renci.SshNet/Session.cs
+++ b/src/Renci.SshNet/Session.cs
@@ -105,6 +105,23 @@ public sealed class Session : ISession
///
private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);
+ private readonly byte[] _inboundPacketSequenceBytes = new byte[4];
+
+ ///
+ /// Gets or sets the incoming packet number.
+ ///
+ private uint InboundPacketSequence
+ {
+ get
+ {
+ return BinaryPrimitives.ReadUInt32BigEndian(_inboundPacketSequenceBytes);
+ }
+ set
+ {
+ BinaryPrimitives.WriteUInt32BigEndian(_inboundPacketSequenceBytes, value);
+ }
+ }
+
///
/// Holds metadata about session messages.
///
@@ -120,11 +137,6 @@ public sealed class Session : ISession
///
private volatile uint _outboundPacketSequence;
- ///
- /// Specifies incoming packet number.
- ///
- private uint _inboundPacketSequence;
-
///
/// WaitHandle to signal that last service request was accepted.
///
@@ -200,7 +212,6 @@ public sealed class Session : ISession
private Socket _socket;
private ArrayBuffer _receiveBuffer = new(4 * 1024);
- private byte[] _plaintextReceiveBuffer = new byte[4 * 1024];
///
/// Gets the session semaphore that controls session channels.
@@ -1213,9 +1224,6 @@ private bool TrySendMessage(Message message)
///
private Message ReceiveMessage(Socket socket)
{
- // the length of the packet sequence field in bytes
- const int inboundPacketSequenceLength = 4;
-
// The length of the "packet length" field in bytes
const int packetLengthFieldLength = 4;
@@ -1272,31 +1280,28 @@ private Message ReceiveMessage(Socket socket)
}
}
- var firstBlock = new ArraySegment(
- _receiveBuffer.DangerousGetUnderlyingBuffer(),
- _receiveBuffer.ActiveStartOffset,
- blockSize);
-
- var plainFirstBlock = firstBlock;
-
- // For ETM or AES-GCM, firstBlock holds the packet length which is
- // not encrypted. Otherwise, we decrypt the first "blockSize" bytes.
- // (For chacha20-poly1305, this means passing the encrypted packet
- // length as AAD).
+ // For ETM or AES-GCM, the first "blockSize" bytes hold the packet length
+ // which is not encrypted. Otherwise, we decrypt them.
+ // (For chacha20-poly1305, this means passing the encrypted packet length
+ // to its AAD cipher instance - it is the awkward difference between the
+ // 3-arg and 5-arg Decrypt, and explains why we don't just decrypt these
+ // bytes in-place).
if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
{
- _serverCipher.SetSequenceNumber(_inboundPacketSequence);
+ _serverCipher.SetSequenceNumber(InboundPacketSequence);
if (_serverMac == null || !_serverEtm)
{
- plainFirstBlock = new ArraySegment(_serverCipher.Decrypt(
- firstBlock.Array,
- firstBlock.Offset,
- firstBlock.Count));
+ var plainFirstBlock = _serverCipher.Decrypt(
+ _receiveBuffer.DangerousGetUnderlyingBuffer(),
+ _receiveBuffer.ActiveStartOffset,
+ blockSize);
+
+ plainFirstBlock.CopyTo(_receiveBuffer.ActiveSpan);
}
}
- var packetLength = BinaryPrimitives.ReadInt32BigEndian(plainFirstBlock);
+ var packetLength = BinaryPrimitives.ReadInt32BigEndian(_receiveBuffer.ActiveReadOnlySpan);
// Test packet minimum and maximum boundaries
if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
@@ -1330,26 +1335,13 @@ private Message ReceiveMessage(Socket socket)
}
}
- // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
- // to generate the hash.
- var plaintextLength = 4 + totalPacketLength - serverMacLength;
-
- if (_plaintextReceiveBuffer.Length < plaintextLength)
- {
- Array.Resize(ref _plaintextReceiveBuffer, Math.Max(plaintextLength, 2 * _plaintextReceiveBuffer.Length));
- }
-
- BinaryPrimitives.WriteUInt32BigEndian(_plaintextReceiveBuffer, _inboundPacketSequence);
-
- plainFirstBlock.AsSpan().CopyTo(_plaintextReceiveBuffer.AsSpan(4));
-
if (_serverMac != null && _serverEtm)
{
// ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)
// sequence_number
_ = _serverMac.TransformBlock(
- inputBuffer: _plaintextReceiveBuffer,
+ inputBuffer: _inboundPacketSequenceBytes,
inputOffset: 0,
inputCount: 4,
outputBuffer: null,
@@ -1377,41 +1369,52 @@ private Message ReceiveMessage(Socket socket)
{
Debug.Assert(numberOfBytesToDecrypt % blockSize == 0);
+ var decryptBuffer = _receiveBuffer.DangerousGetUnderlyingBuffer();
+ var decryptOffset = _receiveBuffer.ActiveStartOffset + blockSize;
+
var numberOfBytesDecrypted = _serverCipher.Decrypt(
- input: _receiveBuffer.DangerousGetUnderlyingBuffer(),
- offset: _receiveBuffer.ActiveStartOffset + blockSize,
+ input: decryptBuffer,
+ offset: decryptOffset,
length: numberOfBytesToDecrypt,
- output: _plaintextReceiveBuffer,
- outputOffset: 4 + blockSize);
+ output: decryptBuffer,
+ outputOffset: decryptOffset);
Debug.Assert(numberOfBytesDecrypted == numberOfBytesToDecrypt);
}
- else
- {
- _receiveBuffer.ActiveReadOnlySpan
- .Slice(blockSize, numberOfBytesToDecrypt)
- .CopyTo(_plaintextReceiveBuffer.AsSpan(4 + blockSize));
- }
if (_serverMac != null && !_serverEtm)
{
// non-ETM mac = MAC(key, sequence_number || unencrypted_packet)
- var clientHash = _serverMac.ComputeHash(_plaintextReceiveBuffer, 0, plaintextLength);
+ // sequence_number
+ _ = _serverMac.TransformBlock(
+ inputBuffer: _inboundPacketSequenceBytes,
+ inputOffset: 0,
+ inputCount: 4,
+ outputBuffer: null,
+ outputOffset: 0);
+
+ // unencrypted_packet
+ _ = _serverMac.TransformBlock(
+ inputBuffer: _receiveBuffer.DangerousGetUnderlyingBuffer(),
+ inputOffset: _receiveBuffer.ActiveStartOffset,
+ inputCount: totalPacketLength - serverMacLength,
+ outputBuffer: null,
+ outputOffset: 0);
+
+ _ = _serverMac.TransformFinalBlock(Array.Empty(), 0, 0);
- if (!CryptoAbstraction.FixedTimeEquals(clientHash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength)))
+ if (!CryptoAbstraction.FixedTimeEquals(_serverMac.Hash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength)))
{
throw new SshConnectionException("MAC error", DisconnectReason.MacError);
}
}
- _receiveBuffer.Discard(totalPacketLength);
-
- var paddingLength = _plaintextReceiveBuffer[inboundPacketSequenceLength + packetLengthFieldLength];
+ var paddingLength = _receiveBuffer.ActiveReadOnlySpan[packetLengthFieldLength];
ArraySegment payload = new(
- _plaintextReceiveBuffer,
- offset: inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength,
+ _receiveBuffer.DangerousGetUnderlyingBuffer(),
+ offset: _receiveBuffer.ActiveStartOffset + packetLengthFieldLength + paddingLengthFieldLength,
count: packetLength - paddingLength - paddingLengthFieldLength);
if (_serverDecompression != null)
@@ -1419,16 +1422,24 @@ private Message ReceiveMessage(Socket socket)
payload = new(_serverDecompression.Decompress(payload.Array, payload.Offset, payload.Count));
}
- _inboundPacketSequence++;
+ var newInboundPacketSequence = ++InboundPacketSequence;
// The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
// It ensures the integrity of key exchange process.
- if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
+ if (newInboundPacketSequence == uint.MaxValue && _isInitialKex)
{
throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
}
- return LoadMessage(payload.Array, payload.Offset, payload.Count);
+ var message = LoadMessage(payload.Array, payload.Offset, payload.Count);
+
+ // The deserialised message may still reference data in the buffer, so calling Discard
+ // here might seem misguided. It is OK because Discard does not mutate the buffer
+ // and it will not be touched again until the next call to ReceiveMessage, which will
+ // only occur after the message has been fully processed.
+ _receiveBuffer.Discard(totalPacketLength);
+
+ return message;
}
private void TrySendDisconnect(DisconnectReason reasonCode, string message)
@@ -1545,7 +1556,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
_logger.LogDebug("[{SessionId}] Enabling strict key exchange extension.", SessionIdHex);
- if (_inboundPacketSequence != 1)
+ if (InboundPacketSequence != 1)
{
throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
}
@@ -1646,7 +1657,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
if (_isStrictKex)
{
- _inboundPacketSequence = 0;
+ InboundPacketSequence = 0;
}
NewKeysReceived?.Invoke(this, new MessageEventArgs(message));