1
0
mirror of https://git.teknik.io/Teknikode/Teknik.git synced 2023-08-02 14:16:22 +02:00
Teknik/Utilities/Cryptography/AesCounterStream.cs

297 lines
8.4 KiB
C#

using System;
using System.Buffers;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Teknik.Utilities.Cryptography
{
public class AesCounterStream : Stream
{
    // The wrapped stream: source of bytes on read, destination of bytes on write.
    private readonly Stream _Inner;

    // The CTR transform; its counter is kept aligned with _Inner.Position by SyncCounter().
    private readonly CounterModeCryptoTransform _Cipher;

    // Set once the cipher and inner stream have been released so disposal is idempotent.
    private bool _Disposed;

    /// <summary>
    /// Performs Encryption or Decryption on a stream with the given Key and IV
    ///
    /// Cipher is AES-256 in CTR mode with no padding. Bytes read from <paramref name="stream"/>
    /// are transformed before being returned to the caller; bytes written to this stream are
    /// transformed before being written to <paramref name="stream"/>.
    /// </summary>
    /// <param name="stream">The underlying stream. It is disposed when this stream is disposed.</param>
    /// <param name="encrypt"><c>true</c> to create an encryptor, <c>false</c> to create a decryptor.</param>
    /// <param name="key">The cipher key.</param>
    /// <param name="iv">The initial counter value.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stream"/> is <c>null</c>.</exception>
    public AesCounterStream(Stream stream, bool encrypt, byte[] key, byte[] iv)
    {
        // A null stream previously failed later with a NullReferenceException inside SyncCounter.
        _Inner = stream ?? throw new ArgumentNullException(nameof(stream));

        // Create the Aes Cipher
        using AesCounterMode aes = new AesCounterMode(iv);
        _Cipher = encrypt
            ? (CounterModeCryptoTransform)aes.CreateEncryptor(key, iv)   // Encrypt
            : (CounterModeCryptoTransform)aes.CreateDecryptor(key, iv);  // Decrypt

        // Align the counter with the inner stream's current position
        SyncCounter();
    }

    /// <summary>
    /// Asynchronously reads from the inner stream into <paramref name="buffer"/> and transforms
    /// the bytes that were read in place.
    /// </summary>
    /// <param name="buffer">Destination for the transformed bytes.</param>
    /// <param name="cancellationToken">Token forwarded to the inner stream's read.</param>
    /// <returns>The number of transformed bytes written to <paramref name="buffer"/>; 0 at end of stream.</returns>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    /// <exception cref="NotSupportedException">The inner stream does not support reading.</exception>
    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
    {
        EnsureCanRead();

        // Forward the token so a pending inner read can actually be cancelled.
        int bytesRead = await _Inner.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
        return TransformInPlace(buffer.Span.Slice(0, bytesRead));
    }

    /// <summary>
    /// Reads up to <paramref name="count"/> bytes from the inner stream into
    /// <paramref name="buffer"/> starting at <paramref name="offset"/> and transforms them in place.
    /// </summary>
    /// <returns>The number of transformed bytes; 0 at end of stream.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The range is outside <paramref name="buffer"/>.</exception>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    /// <exception cref="NotSupportedException">The inner stream does not support reading.</exception>
    public override int Read(byte[] buffer, int offset, int count)
    {
        if (buffer == null)
            throw new ArgumentNullException(nameof(buffer));

        // AsSpan validates offset/count against the array bounds.
        return Read(buffer.AsSpan(offset, count));
    }

    /// <summary>
    /// Reads from the inner stream into <paramref name="buffer"/> and transforms the bytes read in place.
    /// </summary>
    /// <returns>The number of transformed bytes; 0 at end of stream.</returns>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    /// <exception cref="NotSupportedException">The inner stream does not support reading.</exception>
    public override int Read(Span<byte> buffer)
    {
        EnsureCanRead();

        int bytesRead = _Inner.Read(buffer);
        return TransformInPlace(buffer.Slice(0, bytesRead));
    }

    /// <summary>
    /// Transforms <paramref name="count"/> bytes of <paramref name="buffer"/> starting at
    /// <paramref name="offset"/> and writes the result to the inner stream.
    /// The caller's buffer is not modified.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">The range is outside <paramref name="buffer"/>.</exception>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    /// <exception cref="NotSupportedException">The inner stream does not support writing.</exception>
    public override void Write(byte[] buffer, int offset, int count)
    {
        if (buffer == null)
            throw new ArgumentNullException(nameof(buffer));

        // Only the requested range is written. Previously the whole array reached the inner stream.
        Write(new ReadOnlySpan<byte>(buffer, offset, count));
    }

    /// <summary>
    /// Transforms <paramref name="buffer"/> and writes the result to the inner stream.
    /// The caller's data is not modified.
    /// </summary>
    /// <exception cref="ObjectDisposedException">This stream has been disposed.</exception>
    /// <exception cref="NotSupportedException">The inner stream does not support writing.</exception>
    public override void Write(ReadOnlySpan<byte> buffer)
    {
        EnsureCanWrite();

        // Transform a private copy so the caller's data is left untouched.
        byte[] rented = ArrayPool<byte>.Shared.Rent(buffer.Length);
        try
        {
            Span<byte> work = rented.AsSpan(0, buffer.Length);
            buffer.CopyTo(work);
            TransformInPlace(work);
            _Inner.Write(work);
        }
        finally
        {
            // Clear so plaintext/ciphertext does not linger in the shared pool.
            ArrayPool<byte>.Shared.Return(rented, clearArray: true);
        }
    }

    /// <summary>Whether the inner stream is readable and this stream has not been disposed.</summary>
    public override bool CanRead => !_Disposed && _Inner.CanRead;

    /// <summary>Whether the inner stream is seekable and this stream has not been disposed.</summary>
    public override bool CanSeek => !_Disposed && _Inner.CanSeek;

    /// <summary>Whether the inner stream is writable and this stream has not been disposed.</summary>
    public override bool CanWrite => !_Disposed && _Inner.CanWrite;

    /// <summary>The length of the inner stream. CTR mode adds no padding, so lengths match.</summary>
    public override long Length => _Inner.Length;

    /// <summary>
    /// The position of the inner stream. Setting it re-aligns the cipher counter with the new position.
    /// </summary>
    public override long Position
    {
        get => _Inner.Position;
        set
        {
            _Inner.Position = value;
            // Sync the counter
            SyncCounter();
        }
    }

    /// <summary>Flushes the inner stream. The cipher itself buffers nothing.</summary>
    public override void Flush()
    {
        _Inner.Flush();
    }

    /// <summary>Seeks the inner stream and re-aligns the cipher counter with the new position.</summary>
    /// <returns>The new position within the inner stream.</returns>
    public override long Seek(long offset, SeekOrigin origin)
    {
        long newPos = _Inner.Seek(offset, origin);
        // Sync the counter
        SyncCounter();
        return newPos;
    }

    /// <summary>Sets the length of the inner stream and re-aligns the cipher counter.</summary>
    public override void SetLength(long value)
    {
        _Inner.SetLength(value);
        // Truncating below the current position moves Position, which would desync the counter.
        SyncCounter();
    }

    /// <summary>Releases the cipher and disposes the inner stream. Safe to call more than once.</summary>
    protected override void Dispose(bool disposing)
    {
        try
        {
            if (disposing && !_Disposed)
            {
                _Disposed = true;
                _Cipher?.Dispose();
                _Inner?.Dispose();
            }
        }
        finally
        {
            base.Dispose(disposing);
        }
    }

    /// <summary>
    /// Asynchronously releases the cipher and disposes the inner stream. Safe to call more than once.
    /// </summary>
    public override async ValueTask DisposeAsync()
    {
        if (!_Disposed)
        {
            // Mark first: base.DisposeAsync() calls Dispose(true), which must not dispose _Inner a second time.
            _Disposed = true;
            _Cipher?.Dispose();
            await _Inner.DisposeAsync().ConfigureAwait(false);
        }
        await base.DisposeAsync().ConfigureAwait(false);
    }

    // Runs the CTR transform over every byte of `data` in place and returns how many bytes were processed.
    private int TransformInPlace(Span<byte> data)
    {
        if (data.IsEmpty)
            return 0;

        int processed = _Cipher.TransformBlock(data, 0, data.Length);

        // Do we have more? Finalize only the bytes that remain after `processed`.
        if (processed < data.Length)
        {
            int finalProcessed = _Cipher.TransformFinalBlock(data, processed, data.Length - processed);
            if (finalProcessed > 0)
                processed += finalProcessed;
        }

        return processed;
    }

    // Throws if this stream cannot currently be read from.
    private void EnsureCanRead()
    {
        if (_Disposed)
            throw new ObjectDisposedException(GetType().Name);
        if (!_Inner.CanRead)
            throw new NotSupportedException("The underlying stream does not support reading.");
    }

    // Throws if this stream cannot currently be written to.
    private void EnsureCanWrite()
    {
        if (_Disposed)
            throw new ObjectDisposedException(GetType().Name);
        if (!_Inner.CanWrite)
            throw new NotSupportedException("The underlying stream does not support writing.");
    }

    // Re-positions the cipher counter so the keystream lines up with _Inner.Position.
    private void SyncCounter()
    {
        if (_Cipher == null)
            return;

        // Use 64-bit math: the old decimal->int cast overflowed for block indices above int.MaxValue.
        long position = _Inner.Position;
        long blockSize = _Cipher.InputBlockSize;
        long iterations = position / blockSize;
        int counterPos = (int)(position % blockSize);

        // Are we out of sync with the cipher?
        if (_Cipher.Iterations != iterations + 1 || _Cipher.CounterPosition != counterPos)
        {
            // Reset the current counter
            _Cipher.ResetCounter();

            // Iterate the counter to the current position (linear in the block index).
            for (long i = 0; i < iterations; i++)
            {
                _Cipher.IncrementCounter();
            }

            // Encrypt the counter
            _Cipher.EncryptCounter();

            // Set the current position of the counter
            _Cipher.CounterPosition = counterPos;

            // Increment the counter for the next time
            _Cipher.IncrementCounter();
        }
    }
}
}