Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sandbox/DynamicCodeDumper/DynamicCodeDumper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@
<Compile Include="..\..\src\MessagePack\SafeBitConverter.cs">
<Link>Code\SafeBitConverter.cs</Link>
</Compile>
<Compile Include="..\..\src\MessagePack\SipHash.cs">
<Link>Code\SipHash.cs</Link>
</Compile>
<Compile Include="..\..\src\MessagePack\MessagePackCode.cs">
<Link>Code\MessagePackCode.cs</Link>
</Compile>
Expand Down
235 changes: 109 additions & 126 deletions src/MessagePack/MessagePackSecurity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using MessagePack.Formatters;
using MessagePack.Internal;

Expand All @@ -31,6 +33,8 @@ public class MessagePackSecurity
MaximumObjectGraphDepth = 500,
};

private static readonly SipHash Hash = new();

private readonly ObjectFallbackEqualityComparer objectFallbackEqualityComparer;

private MessagePackSecurity()
Expand Down Expand Up @@ -138,62 +142,72 @@ public IEqualityComparer GetEqualityComparer()
return this.HashCollisionResistant ? GetHashCollisionResistantEqualityComparer() : EqualityComparer<object>.Default;
}

private class HashResistantCache<T>
{
internal static readonly IEqualityComparer<T>? EqualityComparer;

static HashResistantCache()
{
// We have to specially handle some 32-bit types (e.g. float) where multiple in-memory representations should hash to the same value.
// Any type supported by the PrimitiveObjectFormatter should be added here if supporting it as a key in a collection makes sense.
EqualityComparer =
typeof(T) == typeof(bool) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<bool>.Instance :
typeof(T) == typeof(char) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<char>.Instance :
typeof(T) == typeof(sbyte) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<sbyte>.Instance :
typeof(T) == typeof(byte) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<byte>.Instance :
typeof(T) == typeof(short) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<short>.Instance :
typeof(T) == typeof(ushort) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<ushort>.Instance :
typeof(T) == typeof(int) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<int>.Instance :
typeof(T) == typeof(uint) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<uint>.Instance :
typeof(T) == typeof(long) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<long>.Instance :
typeof(T) == typeof(ulong) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<ulong>.Instance :
typeof(T) == typeof(Guid) ? (IEqualityComparer<T>)CollisionResistantHasherUnmanaged<Guid>.Instance :

// Data types that are managed or have multiple in-memory representations for equivalent values:
typeof(T) == typeof(float) ? (IEqualityComparer<T>)SingleEqualityComparer.Instance :
typeof(T) == typeof(double) ? (IEqualityComparer<T>)DoubleEqualityComparer.Instance :
typeof(T) == typeof(string) ? (IEqualityComparer<T>)StringEqualityComparer.Instance :
typeof(T) == typeof(DateTime) ? (IEqualityComparer<T>)DateTimeEqualityComparer.Instance :
typeof(T) == typeof(DateTimeOffset) ? (IEqualityComparer<T>)DateTimeOffsetEqualityComparer.Instance :

// Call out each primitive behind an enum explicitly to avoid dynamically generating code.
typeof(T).GetTypeInfo().IsEnum && typeof(T).GetTypeInfo().GetEnumUnderlyingType() is Type underlying ? (
underlying == typeof(byte) ? CollisionResistantEnumHasher<T, byte>.Instance :
underlying == typeof(sbyte) ? CollisionResistantEnumHasher<T, sbyte>.Instance :
underlying == typeof(ushort) ? CollisionResistantEnumHasher<T, ushort>.Instance :
underlying == typeof(short) ? CollisionResistantEnumHasher<T, short>.Instance :
underlying == typeof(uint) ? CollisionResistantEnumHasher<T, uint>.Instance :
underlying == typeof(int) ? CollisionResistantEnumHasher<T, int>.Instance :
underlying == typeof(ulong) ? CollisionResistantEnumHasher<T, ulong>.Instance :
underlying == typeof(long) ? CollisionResistantEnumHasher<T, long>.Instance :
null) :

// Failsafe. If we don't recognize the type, don't assume we have a good, secure hash function for it.
null;
}
}

/// <summary>
/// Returns a hash collision resistant equality comparer.
/// </summary>
/// <typeparam name="T">The type of key that will be hashed in the collection.</typeparam>
/// <returns>A hash collision resistant equality comparer.</returns>
protected virtual IEqualityComparer<T> GetHashCollisionResistantEqualityComparer<T>()
{
IEqualityComparer<T>? result = null;
if (typeof(T).GetTypeInfo().IsEnum)
if (HashResistantCache<T>.EqualityComparer is { } result)
{
Type underlyingType = typeof(T).GetTypeInfo().GetEnumUnderlyingType();
result =
underlyingType == typeof(sbyte) ? CollisionResistantHasher<T>.Instance :
underlyingType == typeof(byte) ? CollisionResistantHasher<T>.Instance :
underlyingType == typeof(short) ? CollisionResistantHasher<T>.Instance :
underlyingType == typeof(ushort) ? CollisionResistantHasher<T>.Instance :
underlyingType == typeof(int) ? CollisionResistantHasher<T>.Instance :
underlyingType == typeof(uint) ? CollisionResistantHasher<T>.Instance :
null;
return result;
}
else

if (typeof(T) == typeof(object))
{
// For anything 32-bits and under, our fallback base secure hasher is usually adequate since it makes the hash unpredictable.
// We should have special implementations for any value that is larger than 32-bits in order to make sure
// that all the data gets hashed securely rather than trivially and predictably compressed into 32-bits before being hashed.
// We also have to specially handle some 32-bit types (e.g. float) where multiple in-memory representations should hash to the same value.
// Any type supported by the PrimitiveObjectFormatter should be added here if supporting it as a key in a collection makes sense.
result =

// 32-bits or smaller:
typeof(T) == typeof(bool) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(char) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(sbyte) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(byte) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(short) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(ushort) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(int) ? CollisionResistantHasher<T>.Instance :
typeof(T) == typeof(uint) ? CollisionResistantHasher<T>.Instance :

// Larger than 32-bits (or otherwise require special handling):
typeof(T) == typeof(long) ? (IEqualityComparer<T>)Int64EqualityComparer.Instance :
typeof(T) == typeof(ulong) ? (IEqualityComparer<T>)UInt64EqualityComparer.Instance :
typeof(T) == typeof(float) ? (IEqualityComparer<T>)SingleEqualityComparer.Instance :
typeof(T) == typeof(double) ? (IEqualityComparer<T>)DoubleEqualityComparer.Instance :
typeof(T) == typeof(string) ? (IEqualityComparer<T>)StringEqualityComparer.Instance :
typeof(T) == typeof(Guid) ? (IEqualityComparer<T>)GuidEqualityComparer.Instance :
typeof(T) == typeof(DateTime) ? (IEqualityComparer<T>)DateTimeEqualityComparer.Instance :
typeof(T) == typeof(DateTimeOffset) ? (IEqualityComparer<T>)DateTimeOffsetEqualityComparer.Instance :
typeof(T) == typeof(object) ? (IEqualityComparer<T>)this.objectFallbackEqualityComparer :
null;
return (IEqualityComparer<T>)this.objectFallbackEqualityComparer;
}

// Any type we don't explicitly whitelist here shouldn't be allowed to use as the key in a hash-based collection since it isn't known to be hash resistant.
// This method can of course be overridden to add more hash collision resistant type support, or the deserializing party can indicate that the data is Trusted
// so that this method doesn't even get called.
return result ?? throw new TypeAccessException($"No hash-resistant equality comparer available for type: {typeof(T)}");
throw new TypeAccessException($"No hash-resistant equality comparer available for type: {typeof(T)}");
}

/// <summary>
Expand Down Expand Up @@ -233,25 +247,41 @@ public void DepthStep(ref MessagePackReader reader)
/// </remarks>
protected virtual MessagePackSecurity Clone() => new MessagePackSecurity(this);

private static int SecureHash<T>(T value)
where T : unmanaged
{
Span<T> span = stackalloc T[1];
span[0] = value;
return unchecked((int)Hash.Compute(MemoryMarshal.Cast<T, byte>(span)));
}

private static int SecureHash(ReadOnlySpan<byte> data) => unchecked((int)Hash.Compute(data));

/// <summary>
/// A hash collision resistant implementation of <see cref="IEqualityComparer{T}"/>.
/// </summary>
/// <typeparam name="T">The type of key that will be hashed.</typeparam>
private class CollisionResistantHasher<T> : IEqualityComparer<T>, IEqualityComparer
private abstract class CollisionResistantHasher<T> : IEqualityComparer<T>, IEqualityComparer
{
internal static readonly CollisionResistantHasher<T> Instance = new CollisionResistantHasher<T>();

public bool Equals(T? x, T? y) => EqualityComparer<T?>.Default.Equals(x, y);

bool IEqualityComparer.Equals(object? x, object? y) => ((IEqualityComparer)EqualityComparer<T>.Default).Equals(x, y);

public int GetHashCode(object obj) => this.GetHashCode((T)obj);

public virtual int GetHashCode(T value) => HashCode.Combine(value);
public abstract int GetHashCode(T value);
}

private class CollisionResistantHasherUnmanaged<T> : CollisionResistantHasher<T>
where T : unmanaged
{
internal static readonly CollisionResistantHasherUnmanaged<T> Instance = new();

public override int GetHashCode(T value) => SecureHash(value);
}

/// <summary>
/// A special hash-resistent equality comparer that defers picking the actual implementation
/// A special hash-resistant equality comparer that defers picking the actual implementation
/// till it can check the runtime type of each value to be hashed.
/// </summary>
private class ObjectFallbackEqualityComparer : IEqualityComparer<object>, IEqualityComparer
Expand Down Expand Up @@ -304,116 +334,69 @@ public int GetHashCode(object value)
}
}

private class UInt64EqualityComparer : CollisionResistantHasher<ulong>
private class SingleEqualityComparer : CollisionResistantHasherUnmanaged<float>
{
internal static new readonly UInt64EqualityComparer Instance = new UInt64EqualityComparer();

public override int GetHashCode(ulong value) => HashCode.Combine((uint)(value >> 32), unchecked((uint)value));
}

private class Int64EqualityComparer : CollisionResistantHasher<long>
{
internal static new readonly Int64EqualityComparer Instance = new Int64EqualityComparer();

public override int GetHashCode(long value) => HashCode.Combine((int)(value >> 32), unchecked((int)value));
}

private class SingleEqualityComparer : CollisionResistantHasher<float>
{
internal static new readonly SingleEqualityComparer Instance = new SingleEqualityComparer();
internal static new readonly SingleEqualityComparer Instance = new();

public override unsafe int GetHashCode(float value)
{
// Special check for 0.0 so that the hash of 0.0 and -0.0 will equal.
if (value == 0.0f)
=> base.GetHashCode(value switch
{
return HashCode.Combine(0);
}

// Standardize on the binary representation of NaN prior to hashing.
if (float.IsNaN(value))
{
value = float.NaN;
}

int l = *(int*)&value;
return l;
}
0.0f => 0, // Special check for 0.0 so that the hash of 0.0 and -0.0 will equal.
float.NaN => float.NaN, // Standardize on the binary representation of NaN prior to hashing.
_ => value,
});
}

private class DoubleEqualityComparer : CollisionResistantHasher<double>
private class DoubleEqualityComparer : CollisionResistantHasherUnmanaged<double>
{
internal static new readonly DoubleEqualityComparer Instance = new DoubleEqualityComparer();
internal static new readonly DoubleEqualityComparer Instance = new();

public override unsafe int GetHashCode(double value)
{
// Special check for 0.0 so that the hash of 0.0 and -0.0 will equal.
if (value == 0.0)
=> base.GetHashCode(value switch
{
return HashCode.Combine(0);
}
0.0 => 0, // Special check for 0.0 so that the hash of 0.0 and -0.0 will equal.
double.NaN => double.NaN, // Standardize on the binary representation of NaN prior to hashing.
_ => value,
});
}

// Standardize on the binary representation of NaN prior to hashing.
if (double.IsNaN(value))
{
value = double.NaN;
}
private class DateTimeEqualityComparer : CollisionResistantHasherUnmanaged<DateTime>
{
internal static new readonly DateTimeEqualityComparer Instance = new();

long l = *(long*)&value;
return HashCode.Combine((int)(l >> 32), unchecked((int)l));
}
public override unsafe int GetHashCode(DateTime value) => SecureHash(value.Ticks);
}

private class GuidEqualityComparer : CollisionResistantHasher<Guid>
private class DateTimeOffsetEqualityComparer : CollisionResistantHasherUnmanaged<DateTimeOffset>
{
internal static new readonly GuidEqualityComparer Instance = new GuidEqualityComparer();

public override unsafe int GetHashCode(Guid value)
{
var hash = default(HashCode);
int* pGuid = (int*)&value;
for (int i = 0; i < sizeof(Guid) / sizeof(int); i++)
{
hash.Add(pGuid[i]);
}
internal static new readonly DateTimeOffsetEqualityComparer Instance = new();

return hash.ToHashCode();
}
public override unsafe int GetHashCode(DateTimeOffset value) => SecureHash(value.UtcDateTime.Ticks);
}

private class StringEqualityComparer : CollisionResistantHasher<string>
{
internal static new readonly StringEqualityComparer Instance = new StringEqualityComparer();
internal static readonly StringEqualityComparer Instance = new();

public override int GetHashCode(string value)
{
#if NETCOREAPP
// .NET Core already has a secure string hashing function. Just use it.
return value?.GetHashCode() ?? 0;
#else
var hash = default(HashCode);
for (int i = 0; i < value.Length; i++)
{
hash.Add(value[i]);
}

return hash.ToHashCode();
#endif
// The Cast call could result in OverflowException at runtime if value is greater than 1bn chars in length.
return SecureHash(MemoryMarshal.Cast<char, byte>(value.AsSpan()));
}
}

private class DateTimeEqualityComparer : CollisionResistantHasher<DateTime>
private class CollisionResistantEnumHasher<TEnum, TUnderlying> : IEqualityComparer<TEnum>, IEqualityComparer
where TUnderlying : unmanaged
{
internal static new readonly DateTimeEqualityComparer Instance = new DateTimeEqualityComparer();
internal static readonly CollisionResistantEnumHasher<TEnum, TUnderlying> Instance = new();

public override unsafe int GetHashCode(DateTime value) => HashCode.Combine((int)(value.Ticks >> 32), unchecked((int)value.Ticks), value.Kind);
}
public bool Equals(TEnum? x, TEnum? y) => EqualityComparer<TEnum?>.Default.Equals(x, y);

private class DateTimeOffsetEqualityComparer : CollisionResistantHasher<DateTimeOffset>
{
internal static new readonly DateTimeOffsetEqualityComparer Instance = new DateTimeOffsetEqualityComparer();
public int GetHashCode(TEnum obj) => SecureHash(Unsafe.As<TEnum, TUnderlying>(ref obj));

bool IEqualityComparer.Equals(object? x, object? y) => x is TEnum e1 && y is TEnum e2 && Equals(e1, e2);

public override unsafe int GetHashCode(DateTimeOffset value) => HashCode.Combine((int)(value.UtcTicks >> 32), unchecked((int)value.UtcTicks));
int IEqualityComparer.GetHashCode(object obj) => GetHashCode((TEnum)obj);
}
}
}
Loading