Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/Npgsql/Internal/EncryptionHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Npgsql.Properties;
using Npgsql.Util;

namespace Npgsql.Internal;

class EncryptionHandler
{
public virtual bool SupportEncryption => false;

public virtual Func<X509Certificate2?>? RootCertificateCallback
{
get => throw new InvalidOperationException(NpgsqlStrings.EncryptionDisabled);
set => throw new InvalidOperationException(NpgsqlStrings.EncryptionDisabled);
}

public virtual Task NegotiateEncryption(NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout, bool async, bool isFirstAttempt)
=> throw new InvalidOperationException(NpgsqlStrings.EncryptionDisabled);

public virtual void AuthenticateSASLSha256Plus(NpgsqlConnector connector, ref string mechanism, ref string cbindFlag, ref string cbind,
ref bool successfulBind)
=> throw new InvalidOperationException(NpgsqlStrings.EncryptionDisabled);
}

sealed class RealEncryptionHandler : EncryptionHandler
{
public override bool SupportEncryption => true;

public override Func<X509Certificate2?>? RootCertificateCallback { get; set; }

public override Task NegotiateEncryption(NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout, bool async, bool isFirstAttempt)
=> connector.NegotiateEncryption(sslMode, timeout, async, isFirstAttempt);

public override void AuthenticateSASLSha256Plus(NpgsqlConnector connector, ref string mechanism, ref string cbindFlag, ref string cbind,
ref bool successfulBind)
=> connector.AuthenticateSASLSha256Plus(ref mechanism, ref cbindFlag, ref cbind, ref successfulBind);
}
118 changes: 61 additions & 57 deletions src/Npgsql/Internal/NpgsqlConnector.Auth.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, Canc
break;

case AuthenticationRequestType.AuthenticationSASL:
await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async, cancellationToken);
await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async,
cancellationToken);
break;

case AuthenticationRequestType.AuthenticationGSS:
Expand Down Expand Up @@ -67,7 +68,7 @@ async Task AuthenticateCleartext(string username, bool async, CancellationToken
await Flush(async, cancellationToken);
}

async Task AuthenticateSASL(List<string> mechanisms, string username, bool async, CancellationToken cancellationToken = default)
async Task AuthenticateSASL(List<string> mechanisms, string username, bool async, CancellationToken cancellationToken)
{
// At the time of writing PostgreSQL only supports SCRAM-SHA-256 and SCRAM-SHA-256-PLUS
var supportsSha256 = mechanisms.Contains("SCRAM-SHA-256");
Expand All @@ -82,61 +83,7 @@ async Task AuthenticateSASL(List<string> mechanisms, string username, bool async
var successfulBind = false;

if (supportsSha256Plus)
{
var sslStream = (SslStream)_stream;
if (sslStream.RemoteCertificate is null)
{
ConnectionLogger.LogWarning("Remote certificate null, falling back to SCRAM-SHA-256");
}
else
{
using var remoteCertificate = new X509Certificate2(sslStream.RemoteCertificate);
// Checking for hashing algorithms
HashAlgorithm? hashAlgorithm = null;
var algorithmName = remoteCertificate.SignatureAlgorithm.FriendlyName;
if (algorithmName is null)
{
ConnectionLogger.LogWarning("Signature algorithm was null, falling back to SCRAM-SHA-256");
}
else if (algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) ||
algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) ||
algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase))
{
hashAlgorithm = SHA256.Create();
}
else if (algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase))
{
hashAlgorithm = SHA384.Create();
}
else if (algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase))
{
hashAlgorithm = SHA512.Create();
}
else
{
ConnectionLogger.LogWarning(
$"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256");
}

if (hashAlgorithm != null)
{
using var _ = hashAlgorithm;

// RFC 5929
mechanism = "SCRAM-SHA-256-PLUS";
// PostgreSQL only supports tls-server-end-point binding
cbindFlag = "p=tls-server-end-point";
// SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine
var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,");

var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData());
var cbindBytes = cbindFlagBytes.Concat(certificateHash).ToArray();
cbind = Convert.ToBase64String(cbindBytes);
successfulBind = true;
IsScramPlus = true;
}
}
}
DataSource.EncryptionHandler.AuthenticateSASLSha256Plus(this, ref mechanism, ref cbindFlag, ref cbind, ref successfulBind);

if (!successfulBind && supportsSha256)
{
Expand Down Expand Up @@ -217,6 +164,63 @@ static string GetNonce()
}
}

internal void AuthenticateSASLSha256Plus(ref string mechanism, ref string cbindFlag, ref string cbind,
ref bool successfulBind)
{
var sslStream = (SslStream)_stream;
if (sslStream.RemoteCertificate is null)
{
ConnectionLogger.LogWarning("Remote certificate null, falling back to SCRAM-SHA-256");
return;
}

using var remoteCertificate = new X509Certificate2(sslStream.RemoteCertificate);
// Checking for hashing algorithms
HashAlgorithm? hashAlgorithm = null;
var algorithmName = remoteCertificate.SignatureAlgorithm.FriendlyName;
if (algorithmName is null)
{
ConnectionLogger.LogWarning("Signature algorithm was null, falling back to SCRAM-SHA-256");
}
else if (algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) ||
algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) ||
algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase))
{
hashAlgorithm = SHA256.Create();
}
else if (algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase))
{
hashAlgorithm = SHA384.Create();
}
else if (algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase))
{
hashAlgorithm = SHA512.Create();
}
else
{
ConnectionLogger.LogWarning(
$"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256");
}

if (hashAlgorithm != null)
{
using var _ = hashAlgorithm;

// RFC 5929
mechanism = "SCRAM-SHA-256-PLUS";
// PostgreSQL only supports tls-server-end-point binding
cbindFlag = "p=tls-server-end-point";
// SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine
var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,");

var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData());
var cbindBytes = cbindFlagBytes.Concat(certificateHash).ToArray();
cbind = Convert.ToBase64String(cbindBytes);
successfulBind = true;
IsScramPlus = true;
}
}

#if NET6_0_OR_GREATER
static byte[] Hi(string str, byte[] salt, int count)
=> Rfc2898DeriveBytes.Pbkdf2(str, salt, count, HashAlgorithmName.SHA256, 256 / 8);
Expand Down
11 changes: 4 additions & 7 deletions src/Npgsql/Internal/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -786,12 +786,9 @@ async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, Cancellat

IsSecure = false;

if ((sslMode is SslMode.Prefer && DataSource.EncryptionNegotiator is not null) ||
if ((sslMode is SslMode.Prefer && DataSource.EncryptionHandler.SupportEncryption) ||
sslMode is SslMode.Require or SslMode.VerifyCA or SslMode.VerifyFull)
{
if (DataSource.EncryptionNegotiator is null)
throw new InvalidOperationException(NpgsqlStrings.EncryptionDisabled);

WriteSslRequest();
await Flush(async, cancellationToken);

Expand All @@ -808,7 +805,7 @@ async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, Cancellat
throw new NpgsqlException("SSL connection requested. No SSL enabled connection from this host is configured.");
break;
case 'S':
await DataSource.EncryptionNegotiator(this, sslMode, timeout, async, isFirstAttempt);
await DataSource.EncryptionHandler.NegotiateEncryption(this, sslMode, timeout, async, isFirstAttempt);
break;
}

Expand Down Expand Up @@ -891,7 +888,7 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout,
if (Settings.RootCertificate is not null)
throw new ArgumentException(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback);

if (DataSource.RootCertificateCallback is not null)
if (DataSource.EncryptionHandler.RootCertificateCallback is not null)
throw new ArgumentException(NpgsqlStrings.CannotUseValidationRootCertificateCallbackWithUserCallback);

certificateValidationCallback = UserCertificateValidationCallback;
Expand All @@ -904,7 +901,7 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout,
certificateValidationCallback = SslTrustServerValidation;
checkCertificateRevocation = false;
}
else if ((caCert = DataSource.RootCertificateCallback?.Invoke()) is not null ||
else if ((caCert = DataSource.EncryptionHandler.RootCertificateCallback?.Invoke()) is not null ||
(certRootPath = Settings.RootCertificate ??
PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is not null)
{
Expand Down
9 changes: 3 additions & 6 deletions src/Npgsql/NpgsqlDataSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public abstract class NpgsqlDataSource : DbDataSource
/// </summary>
internal NpgsqlDatabaseInfo DatabaseInfo { get; private set; } = null!; // Initialized at bootstrapping

internal Func<NpgsqlConnector, SslMode, NpgsqlTimeout, bool, bool, Task>? EncryptionNegotiator { get; }
internal EncryptionHandler EncryptionHandler { get; }
internal RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; }
internal Action<X509CertificateCollection>? ClientCertificatesCallback { get; }

Expand Down Expand Up @@ -90,7 +90,7 @@ internal NpgsqlDataSource(
Configuration = dataSourceConfig;

(LoggingConfiguration,
EncryptionNegotiator,
EncryptionHandler,
UserCertificateValidationCallback,
ClientCertificatesCallback,
_periodicPasswordProvider,
Expand All @@ -100,8 +100,7 @@ internal NpgsqlDataSource(
_userTypeMappings,
_defaultNameTranslator,
ConnectionInitializer,
ConnectionInitializerAsync,
RootCertificateCallback)
ConnectionInitializerAsync)
= dataSourceConfig;
_connectionLogger = LoggingConfiguration.ConnectionLogger;

Expand Down Expand Up @@ -302,8 +301,6 @@ async Task RefreshPassword()
}

#endregion Password management

internal Func<X509Certificate2?>? RootCertificateCallback { get; }

internal abstract ValueTask<NpgsqlConnector> Get(
NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken);
Expand Down
6 changes: 2 additions & 4 deletions src/Npgsql/NpgsqlDataSourceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
using Npgsql.Internal;
using Npgsql.Internal.TypeHandling;
using Npgsql.Internal.TypeMapping;
using Npgsql.Util;

namespace Npgsql;

sealed record NpgsqlDataSourceConfiguration(
NpgsqlLoggingConfiguration LoggingConfiguration,
Func<NpgsqlConnector, SslMode, NpgsqlTimeout, bool, bool, Task>? EncryptionNegotiator,
EncryptionHandler EncryptionHandler,
RemoteCertificateValidationCallback? UserCertificateValidationCallback,
Action<X509CertificateCollection>? ClientCertificatesCallback,
Func<NpgsqlConnectionStringBuilder, CancellationToken, ValueTask<string>>? PeriodicPasswordProvider,
Expand All @@ -23,5 +22,4 @@ sealed record NpgsqlDataSourceConfiguration(
Dictionary<string, IUserTypeMapping> UserTypeMappings,
INpgsqlNameTranslator DefaultNameTranslator,
Action<NpgsqlConnection>? ConnectionInitializer,
Func<NpgsqlConnection, Task>? ConnectionInitializerAsync,
Func<X509Certificate2?>? RootCertificateCallback);
Func<NpgsqlConnection, Task>? ConnectionInitializerAsync);
15 changes: 6 additions & 9 deletions src/Npgsql/NpgsqlSlimDataSourceBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper
ILoggerFactory? _loggerFactory;
bool _sensitiveDataLoggingEnabled;

Func<NpgsqlConnector, SslMode, NpgsqlTimeout, bool, bool, Task>? _encryptionNegotiator;
EncryptionHandler _encryptionHandler = new();
RemoteCertificateValidationCallback? _userCertificateValidationCallback;
Action<X509CertificateCollection>? _clientCertificatesCallback;
Func<X509Certificate2?>? _rootCertificateCallback;

Func<NpgsqlConnectionStringBuilder, CancellationToken, ValueTask<string>>? _periodicPasswordProvider;
TimeSpan _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval;
Expand Down Expand Up @@ -185,7 +184,7 @@ public NpgsqlSlimDataSourceBuilder UseRootCertificate(X509Certificate2? rootCert
/// </remarks>
public NpgsqlSlimDataSourceBuilder UseRootCertificateCallback(Func<X509Certificate2>? rootCertificateCallback)
{
_rootCertificateCallback = rootCertificateCallback;
_encryptionHandler.RootCertificateCallback = rootCertificateCallback;

return this;
}
Expand Down Expand Up @@ -423,8 +422,7 @@ public NpgsqlSlimDataSourceBuilder EnableRecords()
/// </summary>
public NpgsqlSlimDataSourceBuilder EnableEncryption()
{
_encryptionNegotiator = static (connector, sslMode, timeout, async, isFirstAttempt)
=> connector.NegotiateEncryption(sslMode, timeout, async, isFirstAttempt);
_encryptionHandler = new RealEncryptionHandler();

return this;
}
Expand Down Expand Up @@ -503,7 +501,7 @@ NpgsqlDataSourceConfiguration PrepareConfiguration()
{
ConnectionStringBuilder.PostProcessAndValidate();

if (_encryptionNegotiator is null && (_userCertificateValidationCallback is not null || _clientCertificatesCallback is not null))
if (!_encryptionHandler.SupportEncryption && (_userCertificateValidationCallback is not null || _clientCertificatesCallback is not null))
{
throw new InvalidOperationException(NpgsqlStrings.EncryptionDisabled);
}
Expand All @@ -518,7 +516,7 @@ NpgsqlDataSourceConfiguration PrepareConfiguration()
_loggerFactory is null
? NpgsqlLoggingConfiguration.NullConfiguration
: new NpgsqlLoggingConfiguration(_loggerFactory, _sensitiveDataLoggingEnabled),
_encryptionNegotiator,
_encryptionHandler,
_userCertificateValidationCallback,
_clientCertificatesCallback,
_periodicPasswordProvider,
Expand All @@ -528,8 +526,7 @@ _loggerFactory is null
_userTypeMappings,
DefaultNameTranslator,
_syncConnectionInitializer,
_asyncConnectionInitializer,
_rootCertificateCallback);
_asyncConnectionInitializer);
}

void ValidateMultiHost()
Expand Down