-
Notifications
You must be signed in to change notification settings - Fork 0
/
postgrestest.go
190 lines (170 loc) · 6.06 KB
/
postgrestest.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
package postgrestest
import (
"crypto/rand"
"database/sql"
"fmt"
mathrand "math/rand"
"net/url"
"os"
"strings"
_ "github.com/jackc/pgx/v4/stdlib" // postgres driver
"github.com/stretchr/testify/require"
)
type (
	// ConnectFunction opens a *sql.DB for the given server address (DSN).
	ConnectFunction func(address string) (*sql.DB, error)

	// CreateDatabaseFunction creates the database named database using the
	// connection db.
	CreateDatabaseFunction func(db *sql.DB, database string) error

	// DeleteDatabaseFunction drops the database named database using the
	// connection db.
	DeleteDatabaseFunction func(db *sql.DB, database string) error
)
// DefaultConnectFunction opens address with the "pgx" database/sql driver.
// The driver is registered by this package's blank import of
// github.com/jackc/pgx/v4/stdlib.
func DefaultConnectFunction(address string) (*sql.DB, error) {
	const driverName = "pgx"
	return sql.Open(driverName, address)
}
// DefaultCreateDatabaseFunction creates a database by executing
// "CREATE DATABASE <database>" on db and returns the Exec error, if any.
// The name is concatenated into the statement without quoting or escaping,
// so it must already be a valid SQL identifier.
func DefaultCreateDatabaseFunction(db *sql.DB, database string) error {
	statement := "CREATE DATABASE " + database
	if _, err := db.Exec(statement); err != nil {
		return err
	}
	return nil
}
// DefaultDeleteDatabaseFunction drops a database by executing
// "DROP DATABASE <database>" on db and returns the Exec error, if any.
// The name is concatenated into the statement without quoting or escaping.
func DefaultDeleteDatabaseFunction(db *sql.DB, database string) error {
	statement := "DROP DATABASE " + database
	if _, err := db.Exec(statement); err != nil {
		return err
	}
	return nil
}
// ForceDeleteDatabaseFunction drops a database by executing
// "DROP DATABASE <database> WITH (FORCE);" on db and returns the Exec error,
// if any. Use it with WithDeleteDatabaseFunction in place of the default.
// NOTE(review): the FORCE option is not understood by older Postgres servers
// (13+ per upstream docs) — confirm the target server version.
func ForceDeleteDatabaseFunction(db *sql.DB, database string) error {
	statement := fmt.Sprintf("DROP DATABASE %s WITH (FORCE);", database)
	if _, err := db.Exec(statement); err != nil {
		return err
	}
	return nil
}
// Option customizes how NewPostgresTest provisions the test database.
// Options are applied in the order they are passed, so later ones win.
type Option func(*options)
// WithBaseAddress sets the address (DSN) of the base Postgres server on which
// NewPostgresTest creates the test database. It overrides the
// TESTING_POSTGRES_TEST environment variable. An empty address makes
// NewPostgresTest fall back to its built-in default.
func WithBaseAddress(address string) Option {
	setBaseAddress := func(o *options) { o.baseAddress = address }
	return setBaseAddress
}
// WithConnectFunction is an option that sets the ConnectFunction stored in
// NewPostgresTest's options, replacing DefaultConnectFunction.
func WithConnectFunction(connectFunction ConnectFunction) Option {
	setConnect := func(o *options) { o.connectFunction = connectFunction }
	return setConnect
}
// WithCreateDatabaseFunction replaces DefaultCreateDatabaseFunction as the
// function NewPostgresTest calls to create the randomly named test database.
func WithCreateDatabaseFunction(createDatabaseFunction CreateDatabaseFunction) Option {
	setCreate := func(o *options) { o.createDatabaseFunction = createDatabaseFunction }
	return setCreate
}
// WithDeleteDatabaseFunction is an option that allows providing the delete
// database function that NewPostgresTest's cleanup uses to drop the test
// database, replacing DefaultDeleteDatabaseFunction. Pass
// ForceDeleteDatabaseFunction to issue DROP DATABASE ... WITH (FORCE) instead.
// Pass nil to keep the database after the test.
func WithDeleteDatabaseFunction(deleteDatabaseFunction DeleteDatabaseFunction) Option {
	return func(opts *options) {
		opts.deleteDatabaseFunction = deleteDatabaseFunction
	}
}
// options holds every setting that can be customized through the Option
// values passed to NewPostgresTest.
type options struct {
	// baseAddress is the DSN of the server on which the test database is created.
	baseAddress string
	// connectFunction opens connections; set with WithConnectFunction.
	connectFunction ConnectFunction
	// createDatabaseFunction creates the test database; set with
	// WithCreateDatabaseFunction.
	createDatabaseFunction CreateDatabaseFunction
	// deleteDatabaseFunction drops the test database during cleanup; set with
	// WithDeleteDatabaseFunction. A nil value skips deletion.
	deleteDatabaseFunction DeleteDatabaseFunction
}
// TestingT is the subset of a test handle (such as *testing.T) that this
// package's helpers need. Errorf and FailNow are what require assertions
// report through; Cleanup registers the database teardown.
type TestingT interface {
	Cleanup(func())
	Errorf(format string, args ...interface{})
	FailNow()
}
// NewPostgresTest creates a new, randomly named database on the base testing
// Postgres server and returns a DSN for connecting to it.
//
// The base server address is read from the TESTING_POSTGRES_TEST environment
// variable and can be overridden with WithBaseAddress. If the resulting
// address is empty, postgres://postgres:root@localhost:55432 is used. The
// address must be URL-shaped: the returned DSN is the base address with its
// URL path replaced by the new database name.
//
// Connections to the base server are opened with the configured
// ConnectFunction, which is DefaultConnectFunction unless WithConnectFunction
// is used. The database is created with the configured
// CreateDatabaseFunction. A t.Cleanup callback drops it with the configured
// DeleteDatabaseFunction; passing a nil DeleteDatabaseFunction keeps the
// database after the test. Any failure fails the test through require.
func NewPostgresTest(t TestingT, opts ...Option) string {
	if h, ok := t.(interface{ Helper() }); ok {
		h.Helper()
	}

	cfg := &options{
		baseAddress:            os.Getenv("TESTING_POSTGRES_TEST"),
		connectFunction:        DefaultConnectFunction,
		createDatabaseFunction: DefaultCreateDatabaseFunction,
		deleteDatabaseFunction: DefaultDeleteDatabaseFunction,
	}
	for _, opt := range opts {
		opt(cfg)
	}
	if cfg.baseAddress == "" {
		cfg.baseAddress = "postgres://postgres:root@localhost:55432"
	}
	// WithConnectFunction(nil) falls back to the default driver instead of panicking.
	if cfg.connectFunction == nil {
		cfg.connectFunction = DefaultConnectFunction
	}

	// Parse before touching the server, so a malformed address fails the test
	// without first creating a database.
	dsn, err := url.Parse(cfg.baseAddress)
	require.NoError(t, err)

	// withBaseDB runs fn against a short-lived connection to the base server,
	// opened with the configured ConnectFunction. The close is deferred, so it
	// still runs if fn stops the test early through a failed require assertion.
	withBaseDB := func(fn func(db *sql.DB)) {
		db, err := cfg.connectFunction(cfg.baseAddress)
		require.NoError(t, err)
		// The Close error is ignored: a test helper has nothing useful to do with it.
		defer func() { _ = db.Close() }()
		fn(db)
	}

	var databaseName string
	withBaseDB(func(db *sql.DB) {
		databaseName = createTestingDatabase(t, cfg.createDatabaseFunction, db, cfg.baseAddress)
	})

	t.Cleanup(func() {
		if cfg.deleteDatabaseFunction == nil {
			return // the caller opted out of dropping the database
		}
		withBaseDB(func(db *sql.DB) {
			deleteDatabase(t, cfg.deleteDatabaseFunction, db, databaseName)
		})
	})

	dsn.Path = databaseName
	return dsn.String()
}
// AlterTableSequences restarts every sequence listed in pg_class
// (relkind 'S') at a random value in [100, 100099].
//
// Integration tests usually run against a fresh database, where sequence
// numbers in all tables are very close to each other. Bugs that confuse IDs
// from different tables can then go unnoticed. Randomizing the starting
// points helps surface such bugs. Any failure fails the test through require.
func AlterTableSequences(t TestingT, db *sql.DB) {
	if h, ok := t.(interface{ Helper() }); ok {
		h.Helper()
	}

	// Casting the OID to regclass yields a name that is already quoted for SQL
	// and schema-qualified when needed. A bare relname would break on
	// mixed-case or reserved-word names and on sequences outside search_path.
	rows, err := db.Query(`SELECT c.oid::regclass::text FROM pg_class c WHERE c.relkind = 'S';`)
	require.NoError(t, err)
	defer rows.Close()

	var sequences []string
	for rows.Next() {
		var sequence string
		require.NoError(t, rows.Scan(&sequence))
		sequences = append(sequences, sequence)
	}
	// rows.Next also returns false on iteration errors, so check them explicitly.
	require.NoError(t, rows.Err())

	for _, seq := range sequences {
		restartAt := mathrand.Intn(100000) + 100 //nolint:gosec // not security sensitive
		_, err := db.Exec(fmt.Sprintf("ALTER SEQUENCE %s RESTART WITH %d;", seq, restartAt))
		require.NoError(t, err)
	}
}
// createTestingDatabase builds a database name of the form
// testing_db_<16 lowercase hex digits> from 8 bytes of crypto/rand output,
// creates that database on db with createDatabase, and returns the name.
// Any error fails the test through require. addr is accepted but not used.
func createTestingDatabase(t TestingT, createDatabase CreateDatabaseFunction, db *sql.DB, addr string) string {
	if helper, ok := t.(interface{ Helper() }); ok {
		helper.Helper()
	}

	random := make([]byte, 8)
	_, err := rand.Read(random)
	require.NoError(t, err)

	name := fmt.Sprintf("testing_db_%x", random)
	name = strings.ToLower(name)

	require.NoError(t, createDatabase(db, name))
	return name
}
// deleteDatabase drops databaseName by calling the supplied
// DeleteDatabaseFunction with db. It fails the test through require if that
// call returns an error.
func deleteDatabase(t TestingT, deleteDatabase DeleteDatabaseFunction, db *sql.DB, databaseName string) {
	if helper, ok := t.(interface{ Helper() }); ok {
		helper.Helper()
	}
	require.NoError(t, deleteDatabase(db, databaseName))
}