Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions src/Npgsql/Internal/TypeHandlers/UnmappedEnumHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
Expand All @@ -18,27 +18,26 @@ sealed class UnmappedEnumHandler : TextHandler
{
readonly INpgsqlNameTranslator _nameTranslator;

readonly Dictionary<Enum, string> _enumToLabel = new();
readonly Dictionary<string, Enum> _labelToEnum = new();

Type? _resolvedType;
// Note that a separate instance of UnmappedEnumHandler is created for each PG enum type, so concurrency isn't "really" needed.
// However, in theory multiple different CLR enums may be used with the same PG enum type, and even if there's only one, we only know
// about it late (after construction), when the user actually reads/writes with one. So this handler is fully thread-safe.
readonly ConcurrentDictionary<Type, TypeRecord> _types = new();

internal UnmappedEnumHandler(PostgresEnumType pgType, INpgsqlNameTranslator nameTranslator, Encoding encoding)
: base(pgType, encoding)
=> _nameTranslator = nameTranslator;

#region Read

protected internal override async ValueTask<TAny> ReadCustom<TAny>(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription = null)
protected internal override async ValueTask<TAny> ReadCustom<TAny>(NpgsqlReadBuffer buf, int len, bool async, FieldDescription? fieldDescription)
{
var s = await base.Read(buf, len, async, fieldDescription);
if (typeof(TAny) == typeof(string))
return (TAny)(object)s;

if (_resolvedType != typeof(TAny))
Map(typeof(TAny));
var typeRecord = GetTypeRecord(typeof(TAny));

if (!_labelToEnum.TryGetValue(s, out var value))
if (!typeRecord.LabelToEnum.TryGetValue(s, out var value))
throw new InvalidCastException($"Received enum value '{s}' from database which wasn't found on enum {typeof(TAny)}");

// TODO: Avoid boxing
Expand Down Expand Up @@ -66,11 +65,11 @@ int ValidateAndGetLength(object value, ref NpgsqlLengthCache? lengthCache, Npgsq
var type = value.GetType();
if (type == typeof(string))
return base.ValidateAndGetLength((string)value, ref lengthCache, parameter);
if (_resolvedType != type)
Map(type);

var typeRecord = GetTypeRecord(type);

// TODO: Avoid boxing
return _enumToLabel.TryGetValue((Enum)value, out var str)
return typeRecord.EnumToLabel.TryGetValue((Enum)value, out var str)
? base.ValidateAndGetLength(str, ref lengthCache, parameter)
: throw new InvalidCastException($"Can't write value {value} as enum {type}");
}
Expand Down Expand Up @@ -104,11 +103,11 @@ internal Task Write(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? leng
var type = value.GetType();
if (type == typeof(string))
return base.Write((string)value, buf, lengthCache, parameter, async, cancellationToken);
if (_resolvedType != type)
Map(type);

var typeRecord = GetTypeRecord(type);

// TODO: Avoid boxing
if (!_enumToLabel.TryGetValue((Enum)value, out var str))
if (!typeRecord.EnumToLabel.TryGetValue((Enum)value, out var str))
throw new InvalidCastException($"Can't write value {value} as enum {type}");
return base.Write(str, buf, lengthCache, parameter, async, cancellationToken);
}
Expand All @@ -117,25 +116,34 @@ internal Task Write(object value, NpgsqlWriteBuffer buf, NpgsqlLengthCache? leng

#region Misc

void Map([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] Type type)
TypeRecord GetTypeRecord(Type type)
{
Debug.Assert(_resolvedType != type);
#if NETSTANDARD2_0
return _types.GetOrAdd(type, t => CreateTypeRecord(t, _nameTranslator));
#else
return _types.GetOrAdd(type, static (t, translator) => CreateTypeRecord(t, translator), _nameTranslator);
#endif
}

_enumToLabel.Clear();
_labelToEnum.Clear();
static TypeRecord CreateTypeRecord(Type type, INpgsqlNameTranslator nameTranslator)
{
var enumToLabel = new Dictionary<Enum, string>();
var labelToEnum = new Dictionary<string, Enum>();

foreach (var field in type.GetFields(BindingFlags.Static | BindingFlags.Public))
{
var attribute = (PgNameAttribute?)field.GetCustomAttributes(typeof(PgNameAttribute), false).FirstOrDefault();
var enumName = attribute?.PgName ?? _nameTranslator.TranslateMemberName(field.Name);
var enumName = attribute?.PgName ?? nameTranslator.TranslateMemberName(field.Name);
var enumValue = (Enum)field.GetValue(null)!;

_enumToLabel[enumValue] = enumName;
_labelToEnum[enumName] = enumValue;
enumToLabel[enumValue] = enumName;
labelToEnum[enumName] = enumValue;
}

_resolvedType = type;
return new(enumToLabel, labelToEnum);
}

#endregion

record struct TypeRecord(Dictionary<Enum, string> EnumToLabel, Dictionary<string, Enum> LabelToEnum);
}
114 changes: 27 additions & 87 deletions test/Npgsql.Tests/Types/EnumTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,14 @@
using Npgsql.PostgresTypes;
using NpgsqlTypes;
using NUnit.Framework;
using static Npgsql.Util.Statics;
using static Npgsql.Tests.TestUtil;

namespace Npgsql.Tests.Types;

public class EnumTests : MultiplexingTestBase
{
enum Mood { Sad, Ok, Happy }

[PgName("explicitly_named_mood")]
enum MoodUnmapped { Sad, Ok, Happy };

[Test]
public async Task Unmapped_enum()
{
await using var connection = await OpenConnectionAsync();
await using var _ = await GetTempTypeName(connection, out var type);
await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')");
await connection.ReloadTypesAsync();

await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null, isDefault: false);
}
enum AnotherEnum { Value1, Value2 }

[Test]
public async Task Data_source_mapping()
Expand Down Expand Up @@ -73,78 +59,6 @@ public async Task Array()
await AssertType(dataSource, new[] { Mood.Ok, Mood.Happy }, "{ok,happy}", type + "[]", npgsqlDbType: null);
}

[Test]
public async Task Read_unmapped_enum_as_string()
{
using var conn = new NpgsqlConnection(ConnectionString);
conn.Open();
await using var _ = await GetTempTypeName(conn, out var type);

await conn.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('Sad', 'Ok', 'Happy')");
conn.ReloadTypes();
using var cmd = new NpgsqlCommand($"SELECT 'Sad'::{type}, ARRAY['Ok', 'Happy']::{type}[]", conn);
using var reader = await cmd.ExecuteReaderAsync();
reader.Read();
Assert.That(reader[0], Is.EqualTo("Sad"));
Assert.That(reader.GetDataTypeName(0), Is.EqualTo($"public.{type}"));
Assert.That(reader[1], Is.EqualTo(new[] { "Ok", "Happy" }));
}

[Test, Description("Test that a c# string can be written to a backend enum when DbType is unknown")]
public async Task Write_string_to_backend_enum()
{
await using var conn = await OpenConnectionAsync();
await using var _ = await GetTempTypeName(conn, out var type);
await using var __ = await GetTempTableName(conn, out var table);
await conn.ExecuteNonQueryAsync($@"
CREATE TYPE {type} AS ENUM ('Banana', 'Apple', 'Orange');
CREATE TABLE {table} (id SERIAL, value1 {type}, value2 {type});");
await conn.ReloadTypesAsync();
const string expected = "Banana";
using var cmd = new NpgsqlCommand($"INSERT INTO {table} (id, value1, value2) VALUES (default, @p1, @p2);", conn);
cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Unknown, expected);
var p2 = new NpgsqlParameter("p1", NpgsqlDbType.Unknown) {Value = expected};
cmd.Parameters.Add(p2);
await cmd.ExecuteNonQueryAsync();
}

[Test, NonParallelizable]
public async Task Write_unmapped_enum()
{
await using var conn = await OpenConnectionAsync();
await conn.ExecuteNonQueryAsync(@"
DROP TYPE IF EXISTS explicitly_named_mood;
CREATE TYPE explicitly_named_mood AS ENUM ('sad', 'ok', 'happy')");

await conn.ReloadTypesAsync();

await using var cmd = new NpgsqlCommand($"SELECT @p::text", conn)
{
Parameters = { new("p", MoodUnmapped.Happy) }
};

await using var reader = await cmd.ExecuteReaderAsync();
await reader.ReadAsync();

Assert.AreEqual("happy", reader.GetFieldValue<string>(0));
}

[Test, Description("Tests that a a C# enum an be written to an enum backend when passed as dbUnknown")]
public async Task Write_enum_as_NpgsqlDbType_Unknown()
{
await using var conn = await OpenConnectionAsync();
await using var _ = await GetTempTypeName(conn, out var type);
await using var __ = await GetTempTableName(conn, out var table);
await conn.ExecuteNonQueryAsync($@"
CREATE TYPE {type} AS ENUM ('Sad', 'Ok', 'Happy');
CREATE TABLE {table} (value1 {type})");
await conn.ReloadTypesAsync();
var expected = Mood.Happy;
using var cmd = new NpgsqlCommand($"INSERT INTO {table} (value1) VALUES (@p1);", conn);
cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Unknown, expected);
await cmd.ExecuteNonQueryAsync();
}

[Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")]
public async Task Name_translation_default_snake_case()
{
Expand Down Expand Up @@ -177,6 +91,32 @@ public async Task Name_translation_null()
await AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", type, npgsqlDbType: null);
}

[Test]
public async Task Unmapped_enum_as_clr_enum()
{
await using var connection = await OpenConnectionAsync();
await using var _ = await GetTempTypeName(connection, out var type1);
await using var __ = await GetTempTypeName(connection, out var type2);
await connection.ExecuteNonQueryAsync(@$"
CREATE TYPE {type1} AS ENUM ('sad', 'ok', 'happy');
CREATE TYPE {type2} AS ENUM ('value1', 'value2');");
await connection.ReloadTypesAsync();

await AssertType(connection, Mood.Happy, "happy", type1, npgsqlDbType: null, isDefault: false);
await AssertType(connection, AnotherEnum.Value2, "value2", type2, npgsqlDbType: null, isDefault: false);
}

[Test]
public async Task Unmapped_enum_as_string()
{
await using var connection = await OpenConnectionAsync();
await using var _ = await GetTempTypeName(connection, out var type);
await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')");
await connection.ReloadTypesAsync();

await AssertType(connection, "happy", "happy", type, npgsqlDbType: null, isDefaultForWriting: false);
}

enum NameTranslationEnum
{
Simple,
Expand Down