From 0a485ddd6feaa5d31422bbebbee09e6c746c264b Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Wed, 8 May 2024 11:02:39 -0700 Subject: [PATCH 1/5] It's all broken --- internal/compiler/compile.go | 4 + internal/compiler/engine.go | 3 +- internal/config/config.go | 6 ++ internal/pgx/createdb/createdb.go | 114 ++++++++++++++++++++++++++++ internal/pgx/poolcache/poolcache.go | 27 ++++--- internal/sqltest/local/postgres.go | 3 +- 6 files changed, 146 insertions(+), 11 deletions(-) create mode 100644 internal/pgx/createdb/createdb.go diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 84fbb20a3c..7699e74f33 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -3,6 +3,7 @@ package compiler import ( "errors" "fmt" + "hash/fnv" "io" "os" "path/filepath" @@ -31,6 +32,7 @@ func (c *Compiler) parseCatalog(schemas []string) error { return err } merr := multierr.New() + h := fnv.New64() for _, filename := range files { blob, err := os.ReadFile(filename) if err != nil { @@ -38,6 +40,7 @@ func (c *Compiler) parseCatalog(schemas []string) error { continue } contents := migrations.RemoveRollbackStatements(string(blob)) + io.WriteString(h, contents) c.schema = append(c.schema, contents) stmts, err := c.parser.Parse(strings.NewReader(contents)) if err != nil { @@ -51,6 +54,7 @@ func (c *Compiler) parseCatalog(schemas []string) error { } } } + c.schemaHash = fmt.Sprintf("%x", h.Sum(nil)) if len(merr.Errs()) > 0 { return merr } diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index e7e11152c4..9feb9c4f8a 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -25,7 +25,8 @@ type Compiler struct { analyzer analyzer.Analyzer client pb.QuickClient - schema []string + schema []string + schemaHash string } func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) { diff --git a/internal/config/config.go b/internal/config/config.go index 7decfe3698..f37a7a8da6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -59,6 +59,7 @@ const ( type Config struct { Version string `json:"version" yaml:"version"` Cloud Cloud `json:"cloud" yaml:"cloud"` + Servers []Server `json:"servers" yaml:"servers"` SQL []SQL `json:"sql" yaml:"sql"` Overrides Overrides `json:"overrides,omitempty" yaml:"overrides"` Plugins []Plugin `json:"plugins" yaml:"plugins"` @@ -78,6 +79,11 @@ type Cloud struct { AuthToken string `json:"-" yaml:"-"` } +type Server struct { + Name string `json:"name" yaml:"name"` + URI string `json:"uri" yaml:"uri"` +} + type Plugin struct { Name string `json:"name" yaml:"name"` Env []string `json:"env" yaml:"env"` diff --git a/internal/pgx/createdb/createdb.go b/internal/pgx/createdb/createdb.go new file mode 100644 index 0000000000..4690f9b0f9 --- /dev/null +++ b/internal/pgx/createdb/createdb.go @@ -0,0 +1,114 @@ +package createdb + +import ( + "context" + "fmt" + "hash/fnv" + "net/url" + "os" + "strings" + + "github.com/jackc/pgx/v4/pgxpool" + "github.com/jackc/pgx/v5" + "golang.org/x/sync/singleflight" + + migrate "github.com/sqlc-dev/sqlc/internal/migrations" + "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" + "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" +) + +var flight singleflight.Group +var cache = poolcache.New() + +type Server struct { + pool *pgxpool.Pool +} + +func Create(ctx context.Context, srv *pgxpool.Pool, migrations []string) (*pgxpool.Pool, error) { + ctx := context.Background() + + dburi := os.Getenv("POSTGRESQL_SERVER_URI") + if dburi == "" { + t.Skip("POSTGRESQL_SERVER_URI is empty") + } + + postgresPool, err := cache.Open(ctx, dburi) + if err != nil { + t.Fatalf("PostgreSQL pool creation failed: %s", err) + } + + var seed []string + files, err := sqlpath.Glob(migrations) + if err != nil { + t.Fatal(err) + } + + h := fnv.New64() + for _, f := range files { + blob, err := os.ReadFile(f) + if err != nil { + t.Fatal(err) + } + h.Write(blob) + seed = append(seed, migrate.RemoveRollbackStatements(string(blob))) + } + + var name string + if rw { + name = fmt.Sprintf("sqlc_test_%s", id()) + } else { + name = fmt.Sprintf("sqlc_test_%x", h.Sum(nil)) + } + + uri, err := url.Parse(dburi) + if err != nil { + t.Fatal(err) + } + uri.Path = name + dropQuery := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name) + + key := uri.String() + + _, err, _ = flight.Do(key, func() (interface{}, error) { + row := postgresPool.QueryRow(ctx, + fmt.Sprintf(`SELECT datname FROM pg_database WHERE datname = '%s'`, name)) + + var datname string + if err := row.Scan(&datname); err == nil { + t.Logf("database exists: %s", name) + return nil, nil + } + + t.Logf("creating database: %s", name) + if _, err := postgresPool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil { + return nil, err + } + + conn, err := pgx.Connect(ctx, uri.String()) + if err != nil { + return nil, fmt.Errorf("connect %s: %s", name, err) + } + defer conn.Close(ctx) + + for _, q := range seed { + if len(strings.TrimSpace(q)) == 0 { + continue + } + if _, err := conn.Exec(ctx, q); err != nil { + return nil, fmt.Errorf("%s: %s", q, err) + } + } + return nil, nil + }) + if rw || err != nil { + t.Cleanup(func() { + if _, err := postgresPool.Exec(ctx, dropQuery); err != nil { + t.Fatalf("failed cleaning up: %s", err) + } + }) + } + if err != nil { + t.Fatalf("create db: %s", err) + } + return key +} diff --git a/internal/pgx/poolcache/poolcache.go b/internal/pgx/poolcache/poolcache.go index 93401ec936..7293c2202e 100644 --- a/internal/pgx/poolcache/poolcache.go +++ b/internal/pgx/poolcache/poolcache.go @@ -7,13 +7,22 @@ import ( "github.com/jackc/pgx/v5/pgxpool" ) -var lock sync.RWMutex -var pools = map[string]*pgxpool.Pool{} +type Cache struct { + lock sync.RWMutex + pools map[string]*pgxpool.Pool +} + +func New() *Cache { + return &Cache{ + pools: map[string]*pgxpool.Pool{}, + } +} -func New(ctx context.Context, uri string) (*pgxpool.Pool, error) { - lock.RLock() - existing, found := pools[uri] - lock.RUnlock() +// Should only be used in testing contexts +func (c *Cache) Open(ctx context.Context, uri string) (*pgxpool.Pool, error) { + c.lock.RLock() + existing, found := c.pools[uri] + c.lock.RUnlock() if found { return existing, nil @@ -24,9 +33,9 @@ func New(ctx context.Context, uri string) (*pgxpool.Pool, error) { return nil, err } - lock.Lock() - pools[uri] = pool - lock.Unlock() + c.lock.Lock() + c.pools[uri] = pool + c.lock.Unlock() return pool, nil } diff --git a/internal/sqltest/local/postgres.go b/internal/sqltest/local/postgres.go index 3520d42b82..84a0c3d956 100644 --- a/internal/sqltest/local/postgres.go +++ b/internal/sqltest/local/postgres.go @@ -18,6 +18,7 @@ import ( ) var flight singleflight.Group +var cache = poolcache.New() func PostgreSQL(t *testing.T, migrations []string) string { return postgreSQL(t, migrations, true) @@ -36,7 +37,7 @@ func postgreSQL(t *testing.T, migrations []string, rw bool) string { t.Skip("POSTGRESQL_SERVER_URI is empty") } - postgresPool, err := poolcache.New(ctx, dburi) + postgresPool, err := cache.Open(ctx, dburi) if err != nil { t.Fatalf("PostgreSQL pool creation failed: %s", err) } From 2ee6b81be714f5e3fbfb1b60d3a1e5ccd680d209 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 10 May 2024 09:53:12 -0700 Subject: [PATCH 2/5] WIP --- internal/cmd/generate.go | 25 ++++++ internal/compiler/compile.go | 4 - internal/compiler/engine.go | 13 ++- internal/config/config.go | 1 + .../engine/postgresql/analyzer/analyze.go | 50 ++++++++--- .../engine/postgresql/analyzer/createdb.go | 68 +++++++++++++++ internal/pgx/createdb/createdb.go | 87 ++++++------------- 7 files changed, 167 insertions(+), 81 deletions(-) create mode 100644 internal/engine/postgresql/analyzer/createdb.go diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index a7e64e1e46..78ec757917 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log" "os" "path/filepath" "runtime/trace" @@ -27,6 +28,8 @@ import ( "github.com/sqlc-dev/sqlc/internal/info" "github.com/sqlc-dev/sqlc/internal/multierr" "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/pgx/createdb" + "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" "github.com/sqlc-dev/sqlc/internal/plugin" "github.com/sqlc-dev/sqlc/internal/remote" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" @@ -316,9 +319,31 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C } return nil, true } + + { + uri := combo.Global.Servers[0].URI + cache := poolcache.New() + pool, err := cache.Open(ctx, uri) + if err != nil { + log.Println("cache.Open", err) + return nil, false + } + creator := createdb.New(uri, pool) + dburi, db, err := creator.Create(ctx, c.SchemaHash, c.Schema) + if err != nil { + log.Println("creator.Create", err) + } + fmt.Println(db) + + combo.Package.Database.URI = dburi + combo.Package.Database.Managed = false + c.UpdateAnalyzer(combo.Package.Database) + } + if parserOpts.Debug.DumpCatalog { debug.Dump(c.Catalog()) } + if err := c.ParseQueries(sql.Queries, parserOpts); err != nil { fmt.Fprintf(stderr, "# package %s\n", name) if parserErr, ok := err.(*multierr.Error); ok { diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 7699e74f33..84fbb20a3c 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -3,7 +3,6 @@ package compiler import ( "errors" "fmt" - "hash/fnv" "io" "os" "path/filepath" @@ -32,7 +31,6 @@ func (c *Compiler) parseCatalog(schemas []string) error { return err } merr := multierr.New() - h := fnv.New64() for _, filename := range files { blob, err := os.ReadFile(filename) if err != nil { @@ -40,7 +38,6 @@ func (c *Compiler) parseCatalog(schemas []string) error { continue } contents := migrations.RemoveRollbackStatements(string(blob)) - io.WriteString(h, contents) c.schema = append(c.schema, contents) stmts, err := c.parser.Parse(strings.NewReader(contents)) if err != nil { @@ -54,7 +51,6 @@ func (c *Compiler) parseCatalog(schemas []string) error { } } } - c.schemaHash = fmt.Sprintf("%x", h.Sum(nil)) if len(merr.Errs()) > 0 { return merr } diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index 9feb9c4f8a..9f80f6643e 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -25,8 +25,7 @@ type Compiler struct { analyzer analyzer.Analyzer client pb.QuickClient - schema []string - schemaHash string + schema []string } func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) { @@ -53,7 +52,7 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err if conf.Database != nil { if conf.Analyzer.Database == nil || *conf.Analyzer.Database { c.analyzer = analyzer.Cached( - pganalyze.New(c.client, *conf.Database), + pganalyze.New(c.client, combo.Global.Servers, *conf.Database), combo.Global, *conf.Database, ) @@ -65,6 +64,14 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err return c, nil } +func (c *Compiler) UpdateAnalyzer(db *config.Database) { + c.analyzer = analyzer.Cached( + pganalyze.New(c.client, *db), + c.combo.Global, + *db, + ) +} + func (c *Compiler) Catalog() *catalog.Catalog { return c.catalog } diff --git a/internal/config/config.go b/internal/config/config.go index f37a7a8da6..5074a4d3b1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -70,6 +70,7 @@ type Config struct { type Database struct { URI string `json:"uri" yaml:"uri"` Managed bool `json:"managed" yaml:"managed"` + Auto bool `json:"auto" yaml:"auto"` } type Cloud struct { diff --git a/internal/engine/postgresql/analyzer/analyze.go b/internal/engine/postgresql/analyzer/analyze.go index be19fcf539..904f43985c 100644 --- a/internal/engine/postgresql/analyzer/analyze.go +++ b/internal/engine/postgresql/analyzer/analyze.go @@ -4,16 +4,20 @@ import ( "context" "errors" "fmt" + "hash/fnv" + "io" "strings" "sync" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" + "golang.org/x/sync/singleflight" core "github.com/sqlc-dev/sqlc/internal/analysis" "github.com/sqlc-dev/sqlc/internal/config" "github.com/sqlc-dev/sqlc/internal/opts" + "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1" "github.com/sqlc-dev/sqlc/internal/shfmt" "github.com/sqlc-dev/sqlc/internal/sql/ast" @@ -22,22 +26,28 @@ import ( ) type Analyzer struct { - db config.Database - client pb.QuickClient - pool *pgxpool.Pool - dbg opts.Debug - replacer *shfmt.Replacer - formats sync.Map - columns sync.Map - tables sync.Map + db config.Database + client pb.QuickClient + pool *pgxpool.Pool + dbg opts.Debug + replacer *shfmt.Replacer + formats sync.Map + columns sync.Map + tables sync.Map + servers []config.Server + serverCache *poolcache.Cache + flight singleflight.Group } -func New(client pb.QuickClient, db config.Database) *Analyzer { +func New(client pb.QuickClient, servers []config.Server, db config.Database) *Analyzer { return &Analyzer{ - db: db, - dbg: opts.DebugFromEnv(), - client: client, - replacer: shfmt.NewReplacer(nil), + // TODO: Pick first + servers: servers, + db: db, + dbg: opts.DebugFromEnv(), + client: client, + replacer: shfmt.NewReplacer(nil), + serverCache: poolcache.New(), } } @@ -99,6 +109,14 @@ type columnKey struct { Attr uint16 } +func (a *Analyzer) fnv(migrations []string) string { + h := fnv.New64() + for _, query := range migrations { + io.WriteString(h, query) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + // Cache these types in memory func (a *Analyzer) columnInfo(ctx context.Context, field pgconn.FieldDescription) (*pgColumn, error) { key := columnKey{field.TableOID, field.TableAttributeNumber} @@ -211,6 +229,12 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat uri = edb.Uri } else if a.dbg.OnlyManagedDatabases { return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed") + } else if a.db.Auto { + var err error + uri, err = a.createDb(ctx, migrations) + if err != nil { + return nil, err + } } else { uri = a.replacer.Replace(a.db.URI) } diff --git a/internal/engine/postgresql/analyzer/createdb.go b/internal/engine/postgresql/analyzer/createdb.go new file mode 100644 index 0000000000..1b6ba49709 --- /dev/null +++ b/internal/engine/postgresql/analyzer/createdb.go @@ -0,0 +1,68 @@ +package analyzer + +import ( + "context" + "fmt" + "log/slog" + "net/url" + "strings" + + "github.com/jackc/pgx/v5" +) + +func (a *Analyzer) createDb(ctx context.Context, migrations []string) (string, error) { + hash := a.fnv(migrations) + name := fmt.Sprintf("sqlc_%s", hash) + + serverUri := a.replacer.Replace(a.servers[0].URI) + pool, err := a.serverCache.Open(ctx, serverUri) + if err != nil { + return "", err + } + + uri, err := url.Parse(serverUri) + if err != nil { + return "", err + } + uri.Path = name + + key := uri.String() + _, err, _ = a.flight.Do(key, func() (interface{}, error) { + // TODO: Use a parameterized query + row := pool.QueryRow(ctx, + fmt.Sprintf(`SELECT datname FROM pg_database WHERE datname = '%s'`, name)) + + var datname string + if err := row.Scan(&datname); err == nil { + slog.Info("database exists", "name", name) + return nil, nil + } + + slog.Info("creating database", "name", name) + if _, err := pool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil { + return nil, err + } + + conn, err := pgx.Connect(ctx, uri.String()) + if err != nil { + return nil, fmt.Errorf("connect %s: %s", name, err) + } + defer conn.Close(ctx) + + for _, q := range migrations { + if len(strings.TrimSpace(q)) == 0 { + continue + } + if _, err := conn.Exec(ctx, q); err != nil { + return nil, fmt.Errorf("%s: %s", q, err) + } + } + return nil, nil + }) + + if err != nil { + return "", err + } + + return key, err +} diff --git a/internal/pgx/createdb/createdb.go b/internal/pgx/createdb/createdb.go index 4690f9b0f9..b7a1a541a8 100644 --- a/internal/pgx/createdb/createdb.go +++ b/internal/pgx/createdb/createdb.go @@ -3,84 +3,53 @@ package createdb import ( "context" "fmt" - "hash/fnv" + "log/slog" "net/url" - "os" "strings" - "github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/sync/singleflight" - migrate "github.com/sqlc-dev/sqlc/internal/migrations" "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" - "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" ) -var flight singleflight.Group -var cache = poolcache.New() - type Server struct { - pool *pgxpool.Pool + uri string + flight singleflight.Group + cache *poolcache.Cache } -func Create(ctx context.Context, srv *pgxpool.Pool, migrations []string) (*pgxpool.Pool, error) { - ctx := context.Background() - - dburi := os.Getenv("POSTGRESQL_SERVER_URI") - if dburi == "" { - t.Skip("POSTGRESQL_SERVER_URI is empty") - } - - postgresPool, err := cache.Open(ctx, dburi) - if err != nil { - t.Fatalf("PostgreSQL pool creation failed: %s", err) - } - - var seed []string - files, err := sqlpath.Glob(migrations) - if err != nil { - t.Fatal(err) - } - - h := fnv.New64() - for _, f := range files { - blob, err := os.ReadFile(f) - if err != nil { - t.Fatal(err) - } - h.Write(blob) - seed = append(seed, migrate.RemoveRollbackStatements(string(blob))) +func New(uri string) *Server { + return &Server{ + uri: uri, + cache: poolcache.New(), } +} - var name string - if rw { - name = fmt.Sprintf("sqlc_test_%s", id()) - } else { - name = fmt.Sprintf("sqlc_test_%x", h.Sum(nil)) - } +func (s *Server) Create(ctx context.Context, hash string, migrations []string) (string, *pgxpool.Pool, error) { + name := fmt.Sprintf("sqlc_%s", hash) - uri, err := url.Parse(dburi) + uri, err := url.Parse(s.uri) if err != nil { - t.Fatal(err) + return "", nil, err } uri.Path = name - dropQuery := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s" WITH (FORCE)`, name) key := uri.String() - - _, err, _ = flight.Do(key, func() (interface{}, error) { - row := postgresPool.QueryRow(ctx, + _, err, _ = s.flight.Do(key, func() (interface{}, error) { + // TODO: Use a parameterized query + row := s.pool.QueryRow(ctx, fmt.Sprintf(`SELECT datname FROM pg_database WHERE datname = '%s'`, name)) var datname string if err := row.Scan(&datname); err == nil { - t.Logf("database exists: %s", name) + slog.Info("database exists", "name", name) return nil, nil } - t.Logf("creating database: %s", name) - if _, err := postgresPool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil { + slog.Info("creating database", "name", name) + if _, err := s.pool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil { return nil, err } @@ -90,7 +59,7 @@ func Create(ctx context.Context, srv *pgxpool.Pool, migrations []string) (*pgxpo } defer conn.Close(ctx) - for _, q := range seed { + for _, q := range migrations { if len(strings.TrimSpace(q)) == 0 { continue } @@ -100,15 +69,11 @@ func Create(ctx context.Context, srv *pgxpool.Pool, migrations []string) (*pgxpo } return nil, nil }) - if rw || err != nil { - t.Cleanup(func() { - if _, err := postgresPool.Exec(ctx, dropQuery); err != nil { - t.Fatalf("failed cleaning up: %s", err) - } - }) - } + if err != nil { - t.Fatalf("create db: %s", err) + return "", nil, err } - return key + + db, err := s.cache.Open(ctx, key) + return key, db, err } From 47867c5078c95a2a81ac8e6030d91002cf83615f Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 10 May 2024 15:40:24 -0700 Subject: [PATCH 3/5] Add database: auto --- internal/cmd/generate.go | 23 --------- internal/compiler/engine.go | 8 ---- internal/config/validate.go | 6 ++- internal/pgx/createdb/createdb.go | 79 ------------------------------- 4 files changed, 5 insertions(+), 111 deletions(-) delete mode 100644 internal/pgx/createdb/createdb.go diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index 78ec757917..8c21cac610 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "os" "path/filepath" "runtime/trace" @@ -28,8 +27,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/info" "github.com/sqlc-dev/sqlc/internal/multierr" "github.com/sqlc-dev/sqlc/internal/opts" - "github.com/sqlc-dev/sqlc/internal/pgx/createdb" - "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" "github.com/sqlc-dev/sqlc/internal/plugin" "github.com/sqlc-dev/sqlc/internal/remote" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" @@ -320,26 +317,6 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C return nil, true } - { - uri := combo.Global.Servers[0].URI - cache := poolcache.New() - pool, err := cache.Open(ctx, uri) - if err != nil { - log.Println("cache.Open", err) - return nil, false - } - creator := createdb.New(uri, pool) - dburi, db, err := creator.Create(ctx, c.SchemaHash, c.Schema) - if err != nil { - log.Println("creator.Create", err) - } - fmt.Println(db) - - combo.Package.Database.URI = dburi - combo.Package.Database.Managed = false - c.UpdateAnalyzer(combo.Package.Database) - } - if parserOpts.Debug.DumpCatalog { debug.Dump(c.Catalog()) } diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index 9f80f6643e..b08fe94641 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -64,14 +64,6 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err return c, nil } -func (c *Compiler) UpdateAnalyzer(db *config.Database) { - c.analyzer = analyzer.Cached( - pganalyze.New(c.client, *db), - c.combo.Global, - *db, - ) -} - func (c *Compiler) Catalog() *catalog.Catalog { return c.catalog } diff --git a/internal/config/validate.go b/internal/config/validate.go index fadef4fb3b..afd0323eec 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -3,7 +3,11 @@ package config func Validate(c *Config) error { for _, sql := range c.SQL { if sql.Database != nil { - if sql.Database.URI == "" && !sql.Database.Managed { + switch { + case sql.Database.URI == "": + case sql.Database.Managed: + case sql.Database.Auto: + default: return ErrInvalidDatabase } } diff --git a/internal/pgx/createdb/createdb.go b/internal/pgx/createdb/createdb.go deleted file mode 100644 index b7a1a541a8..0000000000 --- a/internal/pgx/createdb/createdb.go +++ /dev/null @@ -1,79 +0,0 @@ -package createdb - -import ( - "context" - "fmt" - "log/slog" - "net/url" - "strings" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" - "golang.org/x/sync/singleflight" - - "github.com/sqlc-dev/sqlc/internal/pgx/poolcache" -) - -type Server struct { - uri string - flight singleflight.Group - cache *poolcache.Cache -} - -func New(uri string) *Server { - return &Server{ - uri: uri, - cache: poolcache.New(), - } -} - -func (s *Server) Create(ctx context.Context, hash string, migrations []string) (string, *pgxpool.Pool, error) { - name := fmt.Sprintf("sqlc_%s", hash) - - uri, err := url.Parse(s.uri) - if err != nil { - return "", nil, err - } - uri.Path = name - - key := uri.String() - _, err, _ = s.flight.Do(key, func() (interface{}, error) { - // TODO: Use a parameterized query - row := s.pool.QueryRow(ctx, - fmt.Sprintf(`SELECT datname FROM pg_database WHERE datname = '%s'`, name)) - - var datname string - if err := row.Scan(&datname); err == nil { - slog.Info("database exists", "name", name) - return nil, nil - } - - slog.Info("creating database", "name", name) - if _, err := s.pool.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, name)); err != nil { - return nil, err - } - - conn, err := pgx.Connect(ctx, uri.String()) - if err != nil { - return nil, fmt.Errorf("connect %s: %s", name, err) - } - defer conn.Close(ctx) - - for _, q := range migrations { - if len(strings.TrimSpace(q)) == 0 { - continue - } - if _, err := conn.Exec(ctx, q); err != nil { - return nil, fmt.Errorf("%s: %s", q, err) - } - } - return nil, nil - }) - - if err != nil { - return "", nil, err - } - - db, err := s.cache.Open(ctx, key) - return key, db, err -} From fda2e458c3d38978f9f1400b52557db5a3892c5c Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 10 May 2024 16:02:43 -0700 Subject: [PATCH 4/5] Try and use auto for tests --- Makefile | 3 +++ internal/endtoend/endtoend_test.go | 21 +++++++++++++-------- internal/sqltest/local/mysql.go | 4 ++++ internal/sqltest/local/postgres.go | 4 ++++ 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index b4b7e80bcf..0632af13ab 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,9 @@ install: test: go test ./... +test-managed: + MYSQL_SERVER_URI="invalid" POSTGRESQL_SERVER_URI="postgres://postgres:mysecretpassword@localhost:5432/postgres" go test ./... + vet: go vet ./... diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index e766e42359..c16ca259a4 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -120,21 +120,26 @@ func TestReplay(t *testing.T) { "managed-db": { Mutate: func(t *testing.T, path string) func(*config.Config) { return func(c *config.Config) { + c.Servers = []config.Server{ + { + Name: "postgres", + URI: local.PostgreSQLServer(), + }, + + { + Name: "mysql", + URI: local.MySQLServer(), + }, + } for i := range c.SQL { - files := []string{} - for _, s := range c.SQL[i].Schema { - files = append(files, filepath.Join(path, s)) - } switch c.SQL[i].Engine { case config.EnginePostgreSQL: - uri := local.ReadOnlyPostgreSQL(t, files) c.SQL[i].Database = &config.Database{ - URI: uri, + Auto: true, } case config.EngineMySQL: - uri := local.MySQL(t, files) c.SQL[i].Database = &config.Database{ - URI: uri, + Auto: true, } default: // pass diff --git a/internal/sqltest/local/mysql.go b/internal/sqltest/local/mysql.go index c61cee3418..9c068a39ba 100644 --- a/internal/sqltest/local/mysql.go +++ b/internal/sqltest/local/mysql.go @@ -18,6 +18,10 @@ import ( var mysqlSync sync.Once var mysqlPool *sql.DB +func MySQLServer() string { + return os.Getenv("MYSQL_SERVER_URI") +} + func MySQL(t *testing.T, migrations []string) string { ctx := context.Background() t.Helper() diff --git a/internal/sqltest/local/postgres.go b/internal/sqltest/local/postgres.go index 84a0c3d956..7b2c16c40a 100644 --- a/internal/sqltest/local/postgres.go +++ b/internal/sqltest/local/postgres.go @@ -28,6 +28,10 @@ func ReadOnlyPostgreSQL(t *testing.T, migrations []string) string { return postgreSQL(t, migrations, false) } +func PostgreSQLServer() string { + return os.Getenv("POSTGRESQL_SERVER_URI") +} + func postgreSQL(t *testing.T, migrations []string, rw bool) string { ctx := context.Background() t.Helper() From c56000d41f0cbf1aa0b5658f25014b3ba8ada699 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 10 May 2024 16:04:00 -0700 Subject: [PATCH 5/5] Fix tests --- internal/config/config_test.go | 1 + internal/config/validate.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 57211d674c..ecb78e74a2 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -82,6 +82,7 @@ func TestInvalidConfig(t *testing.T) { Database: &Database{ URI: "", Managed: false, + Auto: false, }, }}, }) diff --git a/internal/config/validate.go b/internal/config/validate.go index afd0323eec..a30202d329 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -4,7 +4,7 @@ func Validate(c *Config) error { for _, sql := range c.SQL { if sql.Database != nil { switch { - case sql.Database.URI == "": + case sql.Database.URI != "": case sql.Database.Managed: case sql.Database.Auto: default: