using System;
using System.Buffers;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Teknik.Utilities.Cryptography
{
    /// <summary>
    /// A pass-through <see cref="Stream"/> that runs every byte read from, or written to,
    /// an inner stream through a <see cref="CounterModeCryptoTransform"/>.
    /// </summary>
    public class AesCounterStream : Stream
    {
        private Stream _Inner;
        private CounterModeCryptoTransform _Cipher;

        /// <summary>
        /// Performs Encryption or Decryption on a stream with the given Key and IV
        /// </summary>
        /// <remarks>
        /// Cipher is AES-256 in CTR mode with no padding
        /// </remarks>
        /// <param name="stream">
        /// The inner stream that data is read from / written to. This instance disposes it
        /// when it is itself disposed.
        /// </param>
        /// <param name="encrypt"><c>true</c> to create an encryptor, <c>false</c> to create a decryptor.</param>
        /// <param name="key">The cipher key.</param>
        /// <param name="iv">The initial counter / IV value.</param>
        public AesCounterStream(Stream stream, bool encrypt, PooledArray key, PooledArray iv)
        {
            _Inner = stream;

            var keyBytes = key.ToArray();
            var ivBytes = iv.ToArray();

            // Create the Aes Cipher. The factory is disposed at the end of the constructor;
            // only the transform it produced is kept.
            using AesCounterMode aes = new AesCounterMode(iv);
            if (encrypt)
            {
                _Cipher = (CounterModeCryptoTransform)aes.CreateEncryptor(keyBytes, ivBytes); // Encrypt
            }
            else
            {
                _Cipher = (CounterModeCryptoTransform)aes.CreateDecryptor(keyBytes, ivBytes); // Decrypt
            }

            // Align the counter with the inner stream's current position.
            // NOTE(review): this reads _Inner.Position, so the inner stream must support it.
            SyncCounter();
        }

        /// <summary>
        /// Reads from the inner stream into <paramref name="buffer"/> and transforms the bytes in place.
        /// </summary>
        /// <returns>
        /// The number of transformed bytes, 0 at end of stream, or -1 when there is no readable
        /// inner stream (NOTE(review): -1 is not a valid <see cref="Stream"/> read result; kept for compatibility).
        /// </returns>
        public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
        {
            if (_Inner == null || !CanRead)
            {
                return -1;
            }

            // Forward the token so a pending read on the inner stream can be cancelled.
            int bytesRead = await _Inner.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
            return TransformInPlace(buffer.Span, 0, bytesRead);
        }

        /// <summary>
        /// Reads up to <paramref name="count"/> bytes into <paramref name="buffer"/> at
        /// <paramref name="offset"/> and transforms them in place.
        /// </summary>
        /// <returns>The number of transformed bytes, 0 at end of stream, or -1 when there is no readable inner stream.</returns>
        public override int Read(byte[] buffer, int offset, int count)
        {
            if (_Inner == null || !CanRead)
            {
                return -1;
            }

            Span<byte> readBuf = buffer;

            // Read the data from the stream
            int bytesRead = _Inner.Read(readBuf.Slice(offset, count));
            return TransformInPlace(readBuf, offset, bytesRead);
        }

        /// <summary>
        /// Reads from the inner stream into <paramref name="buffer"/> and transforms the bytes in place.
        /// </summary>
        /// <returns>The number of transformed bytes, 0 at end of stream, or -1 when there is no readable inner stream.</returns>
        public override int Read(Span<byte> buffer)
        {
            if (_Inner == null || !CanRead)
            {
                return -1;
            }

            // Read the data from the stream
            int bytesRead = _Inner.Read(buffer);
            return TransformInPlace(buffer, 0, bytesRead);
        }

        /// <summary>
        /// Transforms <paramref name="count"/> bytes of <paramref name="buffer"/> starting at
        /// <paramref name="offset"/>, then writes exactly that range to the inner stream.
        /// The caller's buffer is not modified. Does nothing when there is no writable inner stream.
        /// </summary>
        public override void Write(byte[] buffer, int offset, int count)
        {
            if (_Inner == null || !CanWrite)
            {
                return;
            }

            // AsSpan validates buffer/offset/count for us.
            ReadOnlySpan<byte> source = buffer.AsSpan(offset, count);
            if (source.IsEmpty)
            {
                return;
            }

            // Transform a pooled copy so the caller's plaintext is left untouched.
            byte[] scratch = ArrayPool<byte>.Shared.Rent(source.Length);
            try
            {
                Span<byte> output = scratch.AsSpan(0, source.Length);
                source.CopyTo(output);

                TransformInPlace(output, 0, output.Length);

                // Only the requested range is written, never the whole array.
                _Inner.Write(output);
            }
            finally
            {
                // Clear in case plaintext is still present (e.g. the transform threw).
                ArrayPool<byte>.Shared.Return(scratch, clearArray: true);
            }
        }

        /// <summary><c>true</c> when an inner stream exists and it can read.</summary>
        public override bool CanRead => _Inner != null && _Inner.CanRead;

        /// <summary><c>true</c> when an inner stream exists and it can seek.</summary>
        public override bool CanSeek => _Inner != null && _Inner.CanSeek;

        /// <summary><c>true</c> when an inner stream exists and it can write.</summary>
        public override bool CanWrite => _Inner != null && _Inner.CanWrite;

        /// <summary>The inner stream's length, or -1 when there is no inner stream.</summary>
        public override long Length => _Inner?.Length ?? -1;

        /// <summary>
        /// The inner stream's position, or -1 when there is no inner stream.
        /// Setting it re-synchronizes the cipher counter with the new position.
        /// </summary>
        public override long Position
        {
            get => _Inner?.Position ?? -1;
            set
            {
                if (_Inner != null)
                {
                    _Inner.Position = value;

                    // Sync the counter
                    SyncCounter();
                }
            }
        }

        /// <summary>Flushes the inner stream, if any.</summary>
        public override void Flush()
        {
            _Inner?.Flush();
        }

        /// <summary>
        /// Seeks the inner stream and re-synchronizes the cipher counter with the new position.
        /// </summary>
        /// <returns>The new position, or -1 when there is no inner stream.</returns>
        public override long Seek(long offset, SeekOrigin origin)
        {
            if (_Inner == null)
            {
                return -1;
            }

            long newPos = _Inner.Seek(offset, origin);

            // Sync the counter
            SyncCounter();
            return newPos;
        }

        /// <summary>Sets the inner stream's length, if there is an inner stream.</summary>
        public override void SetLength(long value)
        {
            _Inner?.SetLength(value);
        }

        /// <summary>Disposes the inner stream and the cipher transform.</summary>
        protected override void Dispose(bool disposing)
        {
            try
            {
                if (disposing)
                {
                    _Inner?.Dispose();
                    _Cipher?.Dispose();
                }
            }
            finally
            {
                base.Dispose(disposing);
            }
        }

        /// <summary>
        /// Runs <paramref name="count"/> bytes of <paramref name="data"/>, starting at
        /// <paramref name="offset"/>, through the cipher in place. Any bytes that
        /// <c>TransformBlock</c> reports as unprocessed are handed to <c>TransformFinalBlock</c>.
        /// </summary>
        /// <returns>The total number of bytes the cipher reported as processed (0 when <paramref name="count"/> is not positive).</returns>
        private int TransformInPlace(Span<byte> data, int offset, int count)
        {
            if (count <= 0)
            {
                return 0;
            }

            // Process the buffer
            int processed = _Cipher.TransformBlock(data, offset, count);

            // Do we have more?
            if (processed < count)
            {
                // Finalize the cipher
                var finalProcessed = _Cipher.TransformFinalBlock(data, processed + offset, count);
                if (finalProcessed > 0)
                {
                    processed += finalProcessed;
                }
            }

            return processed;
        }

        /// <summary>
        /// Moves the cipher's counter so that the next transformed byte corresponds to the
        /// inner stream's current position. Does nothing when there is no cipher or inner stream.
        /// </summary>
        /// <remarks>
        /// When out of sync, the counter is reset and incremented once per whole block before
        /// the position. Cost is therefore linear in <c>Position / InputBlockSize</c>.
        /// </remarks>
        private void SyncCounter()
        {
            if (_Cipher == null || _Inner == null)
            {
                return;
            }

            long position = _Inner.Position;
            int blockSize = _Cipher.InputBlockSize;

            // Which block the position falls in, and the byte offset inside that block.
            // checked: throws OverflowException rather than silently wrapping on huge positions.
            int iterations = checked((int)(position / blockSize));
            int counterPos = (int)(position % blockSize);

            // Are we out of sync with the cipher? After a sync the counter has already been
            // incremented past the block in use, hence "iterations + 1".
            if (_Cipher.Iterations == iterations + 1 && _Cipher.CounterPosition == counterPos)
            {
                return;
            }

            // Reset the current counter
            _Cipher.ResetCounter();

            // Iterate the counter to the current position
            for (int i = 0; i < iterations; i++)
            {
                _Cipher.IncrementCounter();
            }

            // Encrypt the counter
            _Cipher.EncryptCounter();

            // Set the current position of the counter
            _Cipher.CounterPosition = counterPos;

            // Increment the counter for the next time
            _Cipher.IncrementCounter();
        }
    }
}