-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconstructsql.go
More file actions
92 lines (76 loc) · 2.01 KB
/
constructsql.go
File metadata and controls
92 lines (76 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package constructsql
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/networkteam/construct/v2"
)
// CollectRows gathers every row of a query result into a slice of T.
//
// It is meant to wrap the (Rows, error) pair returned by an
// ExecutiveQueryBuilder.Query call. A non-nil queryErr is returned as-is and
// rows is left untouched. Otherwise every row is decoded with scanRow (a single
// JSON column unmarshaled into T). rows is closed before returning, and any
// Close error is joined into the returned error. On success the returned slice
// is non-nil, even when there were no rows.
func CollectRows[T any](rows Rows, queryErr error) (result []T, err error) {
	if queryErr != nil {
		return nil, queryErr
	}
	defer func() {
		err = errors.Join(err, rows.Close())
	}()

	// make(..., 0) rather than a nil slice so that an empty result stays
	// non-nil (e.g. it encodes as [] instead of null in JSON).
	collected := make([]T, 0)
	for rows.Next() {
		item, scanErr := scanRow[T](rows)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, item)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return collected, nil
}
type (
	// Rows is the subset of a multi-row query result (for example *sql.Rows)
	// that CollectRows relies on: advancing the cursor, scanning the current
	// row, reporting iteration errors and releasing the result.
	Rows interface {
		RowScanner
		Next() bool
		Err() error
		Close() error
	}

	// RowScanner copies the columns of the current row into the given
	// destinations; both *sql.Row and *sql.Rows satisfy it.
	RowScanner interface {
		Scan(dest ...any) error
	}
)
// sqlToConstructErr translates database/sql errors into construct errors:
// anything matching sql.ErrNoRows (directly or wrapped) becomes
// construct.ErrNotFound. Every other error, including nil, passes through
// unchanged.
func sqlToConstructErr(err error) error {
	if !errors.Is(err, sql.ErrNoRows) {
		return err
	}
	return construct.ErrNotFound
}
// ScanRow decodes a single row, typically the result of an
// ExecutiveQueryBuilder.QueryRow call, into a value of type T.
//
// The row is expected to select a single JSON column, which is unmarshaled
// into T. If err is non-nil, row is not scanned. Instead, the zero value of T
// is returned together with err translated by sqlToConstructErr, so
// sql.ErrNoRows surfaces as construct.ErrNotFound.
func ScanRow[T any](row RowScanner, err error) (T, error) {
	if err == nil {
		return scanRow[T](row)
	}
	var zero T
	return zero, sqlToConstructErr(err)
}
// scanRow reads the current row as a single JSON-encoded column and
// unmarshals it into a new value of type T.
//
// Scan errors are first translated by sqlToConstructErr and then wrapped as
// "scanning row: ...". This lets callers match construct.ErrNotFound (or the
// original error) with errors.Is. JSON errors are wrapped as
// "decoding row: ..." and stay reachable through errors.Is/errors.As. On any
// error the zero value of T is returned, never a partially decoded one.
func scanRow[T any](row RowScanner) (T, error) {
	var zero T

	var data []byte
	if err := row.Scan(&data); err != nil {
		return zero, fmt.Errorf("scanning row: %w", sqlToConstructErr(err))
	}

	// Unmarshal runs as its own statement, so result is read only after
	// decoding has finished. In a single `return result, f(&result)`, the Go
	// spec leaves unspecified whether result is read before or after the call.
	var result T
	if err := json.Unmarshal(data, &result); err != nil {
		return zero, fmt.Errorf("decoding row: %w", err)
	}
	return result, nil
}
// AssertRowsAffected builds a checker for the (sql.Result, error) pair
// returned by an Exec-style call.
//
// The returned function:
//   - returns a non-nil execution error unchanged;
//   - wraps any failure to read the affected-row count;
//   - reports an error naming operation unless exactly expectedRows rows were
//     affected.
//
// Otherwise it returns nil.
func AssertRowsAffected(operation string, expectedRows int) func(sql.Result, error) error {
	want := int64(expectedRows)
	return func(res sql.Result, execErr error) error {
		if execErr != nil {
			return execErr
		}
		got, countErr := res.RowsAffected()
		switch {
		case countErr != nil:
			return fmt.Errorf("getting affected rows: %w", countErr)
		case got != want:
			return fmt.Errorf("%s affected %d rows, but expected exactly %d", operation, got, expectedRows)
		default:
			return nil
		}
	}
}