@@ -4,12 +4,15 @@ import (
44 "context"
55 "database/sql"
66 "fmt"
7+ stdlog "log"
8+ "os"
79 "testing"
810
911 "github.com/google/uuid"
1012 _ "github.com/jackc/pgx/v4/stdlib"
1113 "github.com/kr/pretty"
12- "github.com/stretchr/testify/suite"
14+ "github.com/stretchr/testify/assert"
15+ "github.com/stretchr/testify/require"
1316 "github.com/stripe/pg-schema-diff/internal/pgdump"
1417 "github.com/stripe/pg-schema-diff/internal/pgengine"
1518 "github.com/stripe/pg-schema-diff/pkg/diff"
@@ -60,37 +63,41 @@ type (
6063 // If no expectedDBSchemaDDL is specified, the newSchemaDDL will be used
6164 expectedDBSchemaDDL []string
6265 }
63-
64- acceptanceTestSuite struct {
65- suite.Suite
66- pgEngine * pgengine.Engine
67- }
6866)
6967
70- func (suite * acceptanceTestSuite ) SetupSuite () {
71- engine , err := pgengine .StartEngine ()
72- suite .Require ().NoError (err )
73- suite .pgEngine = engine
74- }
68+ var pgEngine * pgengine.Engine
7569
76- func (suite * acceptanceTestSuite ) TearDownSuite () {
77- suite .pgEngine .Close ()
70+ func TestMain (m * testing.M ) {
71+ engine , err := pgengine .StartEngine ()
72+ if err != nil {
73+ stdlog .Fatalf ("Failed to start engine: %v" , err )
74+ }
75+ pgEngine = engine
76+ defer pgEngine .Close ()
77+ os .Exit (m .Run ())
7878}
7979
8080// Simulates migrating a database and uses pgdump to compare the actual state to the expected state
81- func (suite * acceptanceTestSuite ) runTestCases (acceptanceTestCases []acceptanceTestCase ) {
82- for _ , tc := range acceptanceTestCases {
83- suite .Run (tc .name , func () {
84- suite .runTest (tc )
81+ func runTestCases (t * testing.T , acceptanceTestCases []acceptanceTestCase ) {
82+ for _ , _tc := range acceptanceTestCases {
83+ // Copy the test case since we are using t.Parallel (effectively spinning out a go routine).
84+ tc := _tc
85+ t .Run (tc .name , func (t * testing.T ) {
86+ t .Parallel ()
87+ runTest (t , tc )
8588 })
8689 }
8790}
8891
89- func (suite * acceptanceTestSuite ) runTest (tc acceptanceTestCase ) {
90- uuid .SetRand (& deterministicRandReader {})
91-
92- // Normalize the subtest
93- tc .planOpts = append (tc .planOpts , diff .WithLogger (log .SimpleLogger ()))
92+ func runTest (t * testing.T , tc acceptanceTestCase ) {
93+ deterministicRandReader := & deterministicRandReader {}
94+ // We moved a call to the random when we made tests run in parallel. This caused assertions on exact statements to fail.
95+ // To keep the assertions passing, we will generate a UUID and throw it out. In the future, we should just create
96+ // a more advanced system for asserting random statements that captures variables and allows them to be referenced
97+ // in future assertions.
98+ _ , err := uuid .NewRandomFromReader (deterministicRandReader )
99+ require .NoError (t , err )
100+ tc .planOpts = append ([]diff.PlanOpt {diff .WithLogger (log .SimpleLogger ()), diff .WithRandReader (deterministicRandReader )}, tc .planOpts ... )
94101 if tc .expectedDBSchemaDDL == nil {
95102 tc .expectedDBSchemaDDL = tc .newSchemaDDL
96103 }
@@ -106,70 +113,77 @@ func (suite *acceptanceTestSuite) runTest(tc acceptanceTestCase) {
106113 }
107114 }
108115
116+ engine := pgEngine
117+ if len (tc .roles ) > 0 {
118+ // If the test needs roles (server-wide), provide isolation by spinning out a dedicated pgengine.
119+ dedicatedEngine , err := pgengine .StartEngine ()
120+ require .NoError (t , err )
121+ defer dedicatedEngine .Close ()
122+ engine = dedicatedEngine
123+ }
124+
109125 // Create roles since they are global
110- rootDb , err := sql .Open ("pgx" , suite . pgEngine .GetPostgresDatabaseDSN ())
111- suite . Require (). NoError (err )
126+ rootDb , err := sql .Open ("pgx" , engine .GetPostgresDatabaseDSN ())
127+ require . NoError (t , err )
112128 defer rootDb .Close ()
113129 for _ , r := range tc .roles {
114130 _ , err := rootDb .Exec (fmt .Sprintf ("CREATE ROLE %s" , r ))
115- suite . Require (). NoError (err )
131+ require . NoError (t , err )
116132 }
117- defer func () {
118- // This will drop the roles (and attempt to reset other cluster-level state)
119- suite .Require ().NoError (pgengine .ResetInstance (context .Background (), rootDb ))
120- }()
121133
122134 // Apply old schema DDL to old DB
123- oldDb , err := suite .pgEngine .CreateDatabase ()
124- suite .Require ().NoError (err )
135+ require .NoError (t , err )
136+ oldDb , err := engine .CreateDatabaseWithName (fmt .Sprintf ("pgtemp_%s" , uuid .NewString ()))
137+ require .NoError (t , err )
125138 defer oldDb .DropDB ()
126139 // Apply the old schema
127- suite . Require (). NoError (applyDDL (oldDb , tc .oldSchemaDDL ))
140+ require . NoError (t , applyDDL (oldDb , tc .oldSchemaDDL ))
128141
129142 // Migrate the old DB
130143 oldDBConnPool , err := sql .Open ("pgx" , oldDb .GetDSN ())
131- suite . Require (). NoError (err )
144+ require . NoError (t , err )
132145 defer oldDBConnPool .Close ()
146+ oldDBConnPool .SetMaxOpenConns (1 )
133147
134148 tempDbFactory , err := tempdb .NewOnInstanceFactory (context .Background (), func (ctx context.Context , dbName string ) (* sql.DB , error ) {
135- return sql .Open ("pgx" , suite . pgEngine .GetPostgresDatabaseConnOpts ().With ("dbname" , dbName ).ToDSN ())
136- })
137- suite . Require (). NoError (err )
149+ return sql .Open ("pgx" , engine .GetPostgresDatabaseConnOpts ().With ("dbname" , dbName ).ToDSN ())
150+ }, tempdb . WithRandReader ( deterministicRandReader ) )
151+ require . NoError (t , err )
138152 defer func (tempDbFactory tempdb.Factory ) {
139153 // It's important that this closes properly (the temp database is dropped),
140154 // so assert it has no error for acceptance tests
141- suite . Require (). NoError (tempDbFactory .Close ())
155+ require . NoError (t , tempDbFactory .Close ())
142156 }(tempDbFactory )
143157
144158 plan , err := tc .planFactory (context .Background (), oldDBConnPool , tempDbFactory , tc .newSchemaDDL , tc .planOpts ... )
145159 if tc .expectedPlanErrorIs != nil || len (tc .expectedPlanErrorContains ) > 0 {
146160 if tc .expectedPlanErrorIs != nil {
147- suite .ErrorIs (err , tc .expectedPlanErrorIs )
161+ assert .ErrorIs (t , err , tc .expectedPlanErrorIs )
148162 }
149163 if len (tc .expectedPlanErrorContains ) > 0 {
150- suite .ErrorContains (err , tc .expectedPlanErrorContains )
164+ assert .ErrorContains (t , err , tc .expectedPlanErrorContains )
151165 }
152166 return
153167 }
154- suite . Require (). NoError (err )
168+ require . NoError (t , err )
155169
156- suite . assertValidPlan (plan )
170+ assertValidPlan (t , plan )
157171 if tc .expectEmptyPlan {
158172 // It shouldn't be necessary, but we'll run all checks below this point just in case rather than exiting early
159- suite .Empty (plan .Statements )
173+ assert .Empty (t , plan .Statements )
160174 }
161- suite .ElementsMatch (tc .expectedHazardTypes , getUniqueHazardTypesFromStatements (plan .Statements ), prettySprintPlan (plan ))
175+ assert .ElementsMatch (t , tc .expectedHazardTypes , getUniqueHazardTypesFromStatements (plan .Statements ), prettySprintPlan (plan ))
162176
163177 // Apply the plan
164- suite . Require (). NoError (applyPlan (oldDb , plan ), prettySprintPlan (plan ))
178+ require . NoError (t , applyPlan (oldDb , plan ), prettySprintPlan (plan ))
165179
166180 // Make sure the pgdump after running the migration is the same as the
167181 // pgdump from a database where we directly run the newSchemaDDL
168182 oldDbDump , err := pgdump .GetDump (oldDb , pgdump .WithSchemaOnly ())
169- suite . Require (). NoError (err )
183+ require . NoError (t , err )
170184
171- newDbDump := suite . directlyRunDDLAndGetDump (tc .expectedDBSchemaDDL )
172- suite .Equal (newDbDump , oldDbDump , prettySprintPlan (plan ))
185+ newDbDump := directlyRunDDLAndGetDump (t , engine , tc .expectedDBSchemaDDL )
186+ assert .Equal (t , newDbDump , oldDbDump , prettySprintPlan (plan ))
173187
174188 if tc .expectedPlanDDL != nil {
175189 var generatedDDL []string
@@ -181,30 +195,30 @@ func (suite *acceptanceTestSuite) runTest(tc acceptanceTestCase) {
181195 // We can also make the system more advanced by using tokens in place of the "randomly" generated UUIDs, such
182196 // the test case doesn't need to be updated if the UUID generation changes. If we built this functionality, we
183197 // should also integrate it with the schema_migration_plan_test.go tests.
184- suite .Equal (tc .expectedPlanDDL , generatedDDL , "data packing can change the the generated UUID and DDL" )
198+ assert .Equal (t , tc .expectedPlanDDL , generatedDDL , "data packing can change the the generated UUID and DDL" )
185199 }
186200
187201 // Make sure no diff is found if we try to regenerate a plan
188202 plan , err = tc .planFactory (context .Background (), oldDBConnPool , tempDbFactory , tc .newSchemaDDL , tc .planOpts ... )
189- suite . Require (). NoError (err )
190- suite .Empty (plan .Statements , prettySprintPlan (plan ))
203+ require . NoError (t , err )
204+ assert .Empty (t , plan .Statements , prettySprintPlan (plan ))
191205}
192206
193- func ( suite * acceptanceTestSuite ) assertValidPlan ( plan diff.Plan ) {
207+ func assertValidPlan ( t * testing. T , plan diff.Plan ) {
194208 for _ , stmt := range plan .Statements {
195- suite .Greater (stmt .Timeout .Nanoseconds (), int64 (0 ), "timeout should be greater than 0. stmt=%+v" , stmt )
196- suite .Greater (stmt .LockTimeout .Nanoseconds (), int64 (0 ), "lock timeout should be greater than 0. stmt=%+v" , stmt )
209+ assert .Greater (t , stmt .Timeout .Nanoseconds (), int64 (0 ), "timeout should be greater than 0. stmt=%+v" , stmt )
210+ assert .Greater (t , stmt .LockTimeout .Nanoseconds (), int64 (0 ), "lock timeout should be greater than 0. stmt=%+v" , stmt )
197211 }
198212}
199213
200- func ( suite * acceptanceTestSuite ) directlyRunDDLAndGetDump ( ddl []string ) string {
201- newDb , err := suite . pgEngine .CreateDatabase ()
202- suite . Require (). NoError (err )
214+ func directlyRunDDLAndGetDump ( t * testing. T , engine * pgengine. Engine , ddl []string ) string {
215+ newDb , err := engine .CreateDatabase ()
216+ require . NoError (t , err )
203217 defer newDb .DropDB ()
204- suite . Require (). NoError (applyDDL (newDb , ddl ))
218+ require . NoError (t , applyDDL (newDb , ddl ))
205219
206220 newDbDump , err := pgdump .GetDump (newDb , pgdump .WithSchemaOnly ())
207- suite . Require (). NoError (err )
221+ require . NoError (t , err )
208222 return newDbDump
209223}
210224
@@ -261,7 +275,3 @@ func (r *deterministicRandReader) Read(p []byte) (int, error) {
261275 }
262276 return len (p ), nil
263277}
264-
265- func TestAcceptanceSuite (t * testing.T ) {
266- suite .Run (t , new (acceptanceTestSuite ))
267- }
0 commit comments