Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 80 additions & 44 deletions src/Npgsql/Internal/NpgsqlConnector.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
Expand Down Expand Up @@ -1146,6 +1147,7 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout,
var checkCertificateRevocation = Settings.CheckCertificateRevocation;

RemoteCertificateValidationCallback? certificateValidationCallback;
X509ChainPolicy? certificateChainPolicy = null;
X509Certificate2Collection? caCerts;
string? certRootPath = null;

Expand All @@ -1154,20 +1156,26 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout,
certificateValidationCallback = SslTrustServerValidation;
checkCertificateRevocation = false;
}
else if (((caCerts = DataSource.TransportSecurityHandler.RootCertificatesCallback?.Invoke()) is not null && caCerts.Count > 0) ||
(certRootPath = Settings.RootCertificate ??
PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is not null)
{
certificateValidationCallback = SslRootValidation(sslMode == SslMode.VerifyFull, certRootPath, caCerts);
}
else if (sslMode == SslMode.VerifyCA)
{
certificateValidationCallback = SslVerifyCAValidation;
}
else
{
Debug.Assert(sslMode == SslMode.VerifyFull);
certificateValidationCallback = SslVerifyFullValidation;
if (sslMode == SslMode.VerifyCA)
{
certificateValidationCallback = SslVerifyCAValidation;
}
else
{
Debug.Assert(sslMode == SslMode.VerifyFull);
certificateValidationCallback = SslVerifyFullValidation;
}

if (((caCerts = DataSource.TransportSecurityHandler.RootCertificatesCallback?.Invoke()) is not null && caCerts.Count > 0) ||
(certRootPath = Settings.RootCertificate ??
PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is not null)
{
// Do not use system certificates in addition to custom root certificates
// This is the exact same behavior as libpq
certificateChainPolicy = GetCustomCertificateChainPolicy(certRootPath, caCerts);
}
}

SslStreamCertificateContext? clientCertificateContext = null;
Expand All @@ -1193,6 +1201,7 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout,
var sslStreamOptions = new SslClientAuthenticationOptions
{
TargetHost = host,
CertificateChainPolicy = certificateChainPolicy,
ClientCertificateContext = clientCertificateContext,
EnabledSslProtocols = SslProtocols.None,
CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
Expand Down Expand Up @@ -1996,46 +2005,73 @@ internal void ClearTransaction(Exception? disposeReason = null)
(sender, certificate, chain, sslPolicyErrors)
=> true;

static RemoteCertificateValidationCallback SslRootValidation(bool verifyFull, string? certRootPath, X509Certificate2Collection? caCertificates)
=> (_, certificate, chain, sslPolicyErrors) =>
{
if (certificate is null || chain is null)
return false;
private static X509ChainPolicy GetCustomCertificateChainPolicy(string? certRootPath, X509Certificate2Collection? caCertificates)
{
var certs = GetCustomRootCertificates(certRootPath, caCertificates);

// Even if there was no error while validating, we have to check one more time with the provided certificate
// As this is the exact same behavior as libpq
var certificateChainPolicy = new X509ChainPolicy();

// That's VerifyFull check and we have name mismatch - no reason to check further
if (verifyFull && sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch))
return false;
certificateChainPolicy.CustomTrustStore.AddRange(certs);
certificateChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;

var certs = new X509Certificate2Collection();
certificateChainPolicy.ExtraStore.AddRange(certs);

if (certRootPath is null)
{
Debug.Assert(caCertificates is { Count: > 0 });
certs.AddRange(caCertificates);
}
else
{
Debug.Assert(caCertificates is null or { Count: > 0 });
if (Path.GetExtension(certRootPath).ToUpperInvariant() != ".PFX")
certs.ImportFromPemFile(certRootPath);
return certificateChainPolicy;
}

if (certs.Count == 0)
{
// This is not a PEM certificate, probably PFX
certs.Add(X509CertificateLoader.LoadPkcs12FromFile(certRootPath, null));
}
}
private static readonly ConcurrentDictionary<(string, DateTime?), X509Certificate2Collection> CustomRootCertificateCache = new();

private static X509Certificate2Collection GetCustomRootCertificates(string? certRootPath, X509Certificate2Collection? caCertificates)
{
if (certRootPath is null)
{
Debug.Assert(caCertificates is { Count: > 0 });
return caCertificates;
}
else
{
Debug.Assert(caCertificates is null or { Count: 0 });

chain.ChainPolicy.CustomTrustStore.AddRange(certs);
chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
// Add the file timestamp to the cache key, in case the certificate file is modified while
// the application is running.
// If this happens, a useless old entry will remain in the cache, but we don't really
// expect the file to change in the first place.
var certRootTimeStamp = TryGetFileTimeStamp(certRootPath);

chain.ChainPolicy.ExtraStore.AddRange(certs);
return CustomRootCertificateCache.GetOrAdd((certRootPath, certRootTimeStamp), certRoot =>
LoadRootCertificatesFromFile(certRoot.Item1));
}
}

return chain.Build(certificate as X509Certificate2 ?? new X509Certificate2(certificate));
};
private static X509Certificate2Collection LoadRootCertificatesFromFile(string certRootPath)
{
var certs = new X509Certificate2Collection();

if (Path.GetExtension(certRootPath).ToUpperInvariant() != ".PFX")
certs.ImportFromPemFile(certRootPath);

if (certs.Count == 0)
{
// This is not a PEM certificate, probably PFX
certs.Add(X509CertificateLoader.LoadPkcs12FromFile(certRootPath, null));
}

return certs;
}

private static DateTime? TryGetFileTimeStamp(string path)
{
try
{
return File.GetLastWriteTimeUtc(path);
}
catch
{
// Ignore errors at this point. If the file is loaded afterwards, the code that
// does that will hopefully throw a more meaningful exception.
return null;
}
}

#endregion SSL

Expand Down