Skip to content

Commit e3ce816

Browse files
authored
Expiry mailer: fetch certificates in bulk (letsencrypt#5607)
Use `sa.SelectCertificates` instead of `sa.SelectCertificate` to fetch the entire batch of certificates all at once, instead of doing up to 10k individual certificate selections in serial.
1 parent 6dff9c5 commit e3ce816

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

cmd/expiration-mailer/main.go

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
netmail "net/mail"
1414
"net/url"
1515
"os"
16+
"regexp"
1617
"sort"
1718
"strings"
1819
"text/template"
@@ -334,24 +335,34 @@ func (m *mailer) findExpiringCertificates() error {
334335
}
335336
m.stats.nagsAtCapacity.With(prometheus.Labels{"nag_group": expiresIn.String()}).Set(atCapacity)
336337

337-
// Now we can sequentially retrieve the certificate details for each of the
338-
// certificate status rows
339-
var certs []core.Certificate
340-
for _, serial := range serials {
341-
var cert core.Certificate
342-
cert, err := sa.SelectCertificate(m.dbMap, serial)
343-
if err != nil {
344-
// We can get a NoRowsErr when processing a serial number corresponding
345-
// to a precertificate with no final certificate. Since this certificate
346-
// is not being used by a subscriber, we don't send expiration email about
347-
// it.
348-
if db.IsNoRows(err) {
349-
continue
350-
}
351-
m.log.AuditErrf("expiration-mailer: Error loading cert %q: %s", cert.Serial, err)
352-
return err
338+
if len(serials) == 0 {
339+
continue // nothing to do
340+
}
341+
342+
// Wrap every serial in quotes so they can be interpolated into the query.
343+
quotedSerials := make([]string, len(serials))
344+
serialRegexp := regexp.MustCompile("^[0-9a-f]+$")
345+
for i, s := range serials {
346+
if !serialRegexp.MatchString(s) {
347+
return fmt.Errorf("encountered malformed serial %q", s)
353348
}
354-
certs = append(certs, cert)
349+
quotedSerials[i] = fmt.Sprintf("'%s'", s)
350+
}
351+
352+
// Now we can retrieve the certificate details for all of the status rows.
353+
certWithIDs, err := sa.SelectCertificates(
354+
m.dbMap,
355+
fmt.Sprintf("WHERE serial IN (%s)", strings.Join(quotedSerials, ",")),
356+
nil,
357+
)
358+
if err != nil {
359+
m.log.AuditErrf("expiration-mailer: error retrieving certs: %s", err)
360+
return err
361+
}
362+
363+
certs := make([]core.Certificate, len(certWithIDs))
364+
for i, c := range certWithIDs {
365+
certs[i] = c.Certificate
355366
}
356367

357368
m.log.Infof("Found %d certificates expiring between %s and %s", len(certs),

cmd/expiration-mailer/main_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -724,12 +724,12 @@ func TestDedupOnRegistration(t *testing.T) {
724724
}
725725
regA, err = testCtx.ssa.NewRegistration(ctx, regA)
726726
test.AssertNotError(t, err, "Couldn't store regA")
727+
727728
rawCertA := newX509Cert("happy A",
728729
testCtx.fc.Now().Add(72*time.Hour),
729730
[]string{"example-a.com", "shared-example.com"},
730731
serial1,
731732
)
732-
733733
certDerA, _ := x509.CreateCertificate(rand.Reader, rawCertA, rawCertA, &testKey.PublicKey, &testKey)
734734
certA := &core.Certificate{
735735
RegistrationID: regA.Id,
@@ -764,11 +764,8 @@ func TestDedupOnRegistration(t *testing.T) {
764764

765765
err = testCtx.m.findExpiringCertificates()
766766
test.AssertNotError(t, err, "error calling findExpiringCertificates")
767-
if len(testCtx.mc.Messages) > 1 {
768-
t.Errorf("num of messages, want %d, got %d", 1, len(testCtx.mc.Messages))
769-
}
770-
if len(testCtx.mc.Messages) == 0 {
771-
t.Fatalf("no messages sent")
767+
if len(testCtx.mc.Messages) != 1 {
768+
t.Fatalf("wrong num of messages, want 1, got %d", len(testCtx.mc.Messages))
772769
}
773770
domains := "example-a.com\nexample-b.com\nshared-example.com"
774771
expected := mocks.MailerMessage{

test/v1_integration.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,13 @@ def test_expiration_mailer():
402402

403403
requests.post("http://localhost:9381/clear", data='')
404404
for time in (no_reminder, first_reminder, last_reminder):
405-
print(get_future_output(
406-
["./bin/expiration-mailer", "--config", "%s/expiration-mailer.json" % config_dir],
407-
time))
405+
try:
406+
print(get_future_output(
407+
["./bin/expiration-mailer", "--config", "%s/expiration-mailer.json" % config_dir],
408+
time))
409+
except subprocess.CalledProcessError as e:
410+
print(e.output.decode("unicode-escape"))
411+
raise
408412
resp = requests.get("http://localhost:9381/count?to=%s" % email_addr)
409413
mailcount = int(resp.text)
410414
if mailcount != 2:

0 commit comments

Comments
 (0)