From d91d26aeb9141ae5ea30b776da56695f78f0ebd1 Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Tue, 21 Dec 2021 17:15:16 +0100 Subject: [PATCH 1/5] Update the development guide to reflect the new regenerate script --- docs/guides/development.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/guides/development.md b/docs/guides/development.md index cf84beced6..5140f086eb 100644 --- a/docs/guides/development.md +++ b/docs/guides/development.md @@ -57,10 +57,11 @@ MYSQL_DATABASE dinotest ## Regenerate expected test output If you need to update a large number of expected test output in the -`internal/endtoend/testdata` directory, run the `regenerate.sh` script. +`internal/endtoend/testdata` directory, run the `regenerate` script. ``` -make regen +go build -o ~/go/bin/sqlc-dev ./cmd/sqlc +go run scripts/regenerate/main.go ``` Note that this uses the `sqlc-dev` binary, not `sqlc` so make sure you have an From 98eaf232958447d072672f26da7f1b4d2f99e8fb Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Tue, 21 Dec 2021 17:29:44 +0100 Subject: [PATCH 2/5] Fix small typo in the docs --- docs/howto/insert.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/howto/insert.md b/docs/howto/insert.md index 6e18fe11d7..f47ecf00c4 100644 --- a/docs/howto/insert.md +++ b/docs/howto/insert.md @@ -124,13 +124,13 @@ func (q *Queries) DeleteID(ctx context.Context, id int) (int, error) { return i, err } -const deleteAuhtor = `-- name: DeleteAuthor :one +const deleteAuthor = `-- name: DeleteAuthor :one DELETE FROM authors WHERE id = $1 RETURNING id, bio ` func (q *Queries) DeleteAuthor(ctx context.Context, id int) (Author, error) { - row := q.db.QueryRowContext(ctx, deleteAuhtor, id) + row := q.db.QueryRowContext(ctx, deleteAuthor, id) var i Author err := row.Scan(&i.ID, &i.Bio) return i, err From 6e18a48cbed9d6caae7c7c54419fab8d83f1a37d Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Tue, 21 Dec 2021 17:27:25 +0100 Subject: [PATCH 3/5] Implement support for pgx's CopyFrom This allows type-safe bulk loading with great performance. I didn't implement it for Python and Kotlin. This change is fully backwards compatible, as it only changes the DBTX interface when someone adds their first :copyFrom query (at which point it's reasonable to require the CopyFrom method on the DBTX). --- docs/howto/insert.md | 26 ++++++ internal/codegen/golang/field.go | 3 +- internal/codegen/golang/gen.go | 12 +++ internal/codegen/golang/query.go | 32 ++++++++ internal/codegen/golang/result.go | 8 +- .../codegen/golang/templates/pgx/dbCode.tmpl | 3 + .../golang/templates/pgx/interfaceCode.tmpl | 5 ++ .../golang/templates/pgx/queryCode.tmpl | 49 ++++++++++++ internal/codegen/kotlin/gen.go | 14 +++- internal/codegen/python/gen.go | 14 +++- internal/compiler/parse.go | 19 +++-- internal/compiler/query.go | 5 +- .../testdata/copyfrom/postgresql/pgx/go/db.go | 31 ++++++++ .../copyfrom/postgresql/pgx/go/models.go | 12 +++ .../copyfrom/postgresql/pgx/go/query.sql.go | 79 +++++++++++++++++++ .../copyfrom/postgresql/pgx/query.sql | 8 ++ .../copyfrom/postgresql/pgx/sqlc.json | 13 +++ internal/metadata/meta.go | 5 +- internal/sql/validate/cmd.go | 44 ++++++++++- 19 files changed, 361 insertions(+), 21 deletions(-) create mode 100644 internal/endtoend/testdata/copyfrom/postgresql/pgx/go/db.go create mode 100644 internal/endtoend/testdata/copyfrom/postgresql/pgx/go/models.go create mode 100644 internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go create mode 100644 internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql create mode 100644 internal/endtoend/testdata/copyfrom/postgresql/pgx/sqlc.json diff --git a/docs/howto/insert.md b/docs/howto/insert.md index f47ecf00c4..6a9224ff87 100644 --- a/docs/howto/insert.md +++ b/docs/howto/insert.md @@ -136,3 +136,29 @@ func (q *Queries) DeleteAuthor(ctx context.Context, id int) (Author, error) { return i, err } ``` + +## Using CopyFrom + +PostgreSQL supports the Copy Protocol that can insert rows a lot faster than sequential inserts. You can use this easily with sqlc: + +```sql +CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name text NOT NULL, + bio text NOT NULL +); + +-- name: CreateAuthors :copyFrom +INSERT INTO authors (name, bio) VALUES ($1, $2); +``` + +```go +type CreateAuthorsParams struct { + Name string + Bio string +} + +func (q *Queries) CreateAuthors(ctx context.Context, arg []CreateAuthorsParams) (int64, error) { + return q.db.CopyFrom(ctx, []string{"authors"}, []string{"name", "bio"}, &iteratorForCreateAuthors{rows: arg}) +} +``` diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index e036b0e041..bb095e8a62 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -9,7 +9,8 @@ import ( ) type Field struct { - Name string + Name string // CamelCased name for Go + DBName string // Name as used in the DB Type string Tags map[string]string Comment string diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index ca548e3e58..890abaa806 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -11,6 +11,7 @@ import ( "github.com/kyleconroy/sqlc/internal/codegen" "github.com/kyleconroy/sqlc/internal/compiler" "github.com/kyleconroy/sqlc/internal/config" + "github.com/kyleconroy/sqlc/internal/metadata" ) type Generateable interface { @@ -37,6 +38,7 @@ type tmplCtx struct { EmitInterface bool EmitEmptySlices bool EmitMethodsWithDBArgument bool + UsesCopyFrom bool } func (t *tmplCtx) OutputQuery(sourceName string) bool { @@ -87,6 +89,7 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, EmitPreparedQueries: golang.EmitPreparedQueries, EmitEmptySlices: golang.EmitEmptySlices, EmitMethodsWithDBArgument: golang.EmitMethodsWithDBArgument, + UsesCopyFrom: usesCopyFrom(queries), SQLPackage: SQLPackageFromString(golang.SQLPackage), Q: "`", Package: golang.Package, @@ -160,3 +163,12 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, } return output, nil } + +func usesCopyFrom(queries []Query) bool { + for _, q := range queries { + if q.Cmd == metadata.CmdCopyFrom { + return true + } + } + return false +} diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 171125adc0..dfd332f0f7 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -1,9 +1,11 @@ package golang import ( + "fmt" "strings" "github.com/kyleconroy/sqlc/internal/metadata" + "github.com/kyleconroy/sqlc/internal/sql/ast" ) type QueryValue struct { @@ -38,6 +40,13 @@ func (v QueryValue) Pair() string { return v.Name + " " + v.DefineType() } +func (v QueryValue) SlicePair() string { + if v.isEmpty() { + return "" + } + return v.Name + " []" + v.DefineType() +} + func (v QueryValue) Type() string { if v.Typ != "" { return v.Typ @@ -105,6 +114,17 @@ func (v QueryValue) Params() string { return "\n" + strings.Join(out, ",\n") } +func (v QueryValue) ColumnNames() string { + if v.Struct == nil { + return fmt.Sprintf("[]string{%q}", v.Name) + } + escapedNames := make([]string, len(v.Struct.Fields)) + for i, f := range v.Struct.Fields { + escapedNames[i] = fmt.Sprintf("%q", f.DBName) + } + return "[]string{" + strings.Join(escapedNames, ", ") + "}" +} + func (v QueryValue) Scan() string { var out []string if v.Struct == nil { @@ -140,9 +160,21 @@ type Query struct { SourceName string Ret QueryValue Arg QueryValue + // Used for :copyFrom + Table *ast.TableName } func (q Query) hasRetType() bool { scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany return scanned && !q.Ret.isEmpty() } + +func (q Query) TableIdentifier() string { + escapedNames := make([]string, 0, 3) + for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} { + if p != "" { + escapedNames = append(escapedNames, fmt.Sprintf("%q", p)) + } + } + return "[]string{" + strings.Join(escapedNames, ", ") + "}" +} diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index b213826eb9..28d1c0c2e8 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -160,6 +160,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs SourceName: query.Filename, SQL: query.SQL, Comments: query.Comments, + Table: query.InsertIntoTable, } sqlpkg := SQLPackageFromString(settings.Go.SQLPackage) @@ -291,9 +292,10 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin tags["json:"] = JSONTagName(tagName, settings) } gs.Fields = append(gs.Fields, Field{ - Name: fieldName, - Type: goType(r, c.Column, settings), - Tags: tags, + Name: fieldName, + DBName: colName, + Type: goType(r, c.Column, settings), + Tags: tags, }) if _, found := seen[baseFieldName]; !found { seen[baseFieldName] = []int{i} diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index dbfde50d50..00d624f8b9 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -4,6 +4,9 @@ type DBTX interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row +{{- if .UsesCopyFrom }} + CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) +{{- end }} } {{ if .EmitMethodsWithDBArgument}} diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl index d8ab4d7bb9..0d4940e74e 100644 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl @@ -27,6 +27,11 @@ {{- else if eq .Cmd ":execresult" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {{- end}} + {{- if and (eq .Cmd ":copyFrom") ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) + {{- else if eq .Cmd ":copyFrom" }} + {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) + {{- end}} {{- end}} } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index 21575b3e70..ebf11505ca 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -1,9 +1,11 @@ {{define "queryCodePgx"}} {{range .GoQueries}} {{if $.OutputQuery .SourceName}} +{{if ne .Cmd ":copyFrom"}} const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{escape .SQL}} {{$.Q}} +{{end}} {{if .Arg.EmitStruct}} type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} @@ -112,6 +114,53 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.Co } {{end}} +{{if eq .Cmd ":copyFrom"}} +// iteratorFor{{.MethodName}} implements pgx.CopyFromSource. +type iteratorFor{{.MethodName}} struct { + rows []{{.Arg.DefineType}} + skippedFirstNextCall bool +} + +func (r *iteratorFor{{.MethodName}}) Next() bool { + if len(r.rows) == 0 { + return false + } + if !r.skippedFirstNextCall { + r.skippedFirstNextCall = true + return true + } + r.rows = r.rows[1:] + return len(r.rows) > 0 +} + +func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) { + return []interface{}{ +{{- if .Arg.Struct }} +{{- range .Arg.Struct.Fields }} + r.rows[0].{{.Name}}, +{{- end }} +{{- else }} + r.rows[0], +{{- end }} + }, nil +} + +func (r iteratorFor{{.MethodName}}) Err() error { + return nil +} + +{{range .Comments}}//{{.}} +{{end -}} +{{- if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { + return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) +{{- else}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { + return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) +{{- end}} +} +{{end}} + {{end}} {{end}} {{end}} diff --git a/internal/codegen/kotlin/gen.go b/internal/codegen/kotlin/gen.go index 48c23f7504..7c39b19e2f 100644 --- a/internal/codegen/kotlin/gen.go +++ b/internal/codegen/kotlin/gen.go @@ -3,6 +3,7 @@ package kotlin import ( "bufio" "bytes" + "errors" "fmt" "regexp" "sort" @@ -14,6 +15,7 @@ import ( "github.com/kyleconroy/sqlc/internal/config" "github.com/kyleconroy/sqlc/internal/core" "github.com/kyleconroy/sqlc/internal/inflection" + "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/catalog" ) @@ -458,7 +460,7 @@ func jdbcSQL(s string, engine config.Engine) string { return s } -func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query { +func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) { qs := make([]Query, 0, len(r.Queries)) for _, query := range r.Queries { if query.Name == "" { @@ -467,6 +469,9 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs if query.Cmd == "" { continue } + if query.Cmd == metadata.CmdCopyFrom { + return nil, errors.New("Support for CopyFrom in Kotlin is not implemented") + } gq := Query{ Cmd: query.Cmd, @@ -543,7 +548,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs qs = append(qs, gq) } sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) - return qs + return qs, nil } var ktIfaceTmpl = `// Code generated by sqlc. DO NOT EDIT. @@ -769,7 +774,10 @@ func ktFormat(s string) string { func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) { enums := buildEnums(r, settings) structs := buildDataClasses(r, settings) - queries := buildQueries(r, settings, structs) + queries, err := buildQueries(r, settings, structs) + if err != nil { + return nil, err + } i := &importer{ Settings: settings, diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go index 85186daa8c..408aeb3855 100644 --- a/internal/codegen/python/gen.go +++ b/internal/codegen/python/gen.go @@ -1,6 +1,7 @@ package python import ( + "errors" "fmt" "log" "regexp" @@ -12,6 +13,7 @@ import ( "github.com/kyleconroy/sqlc/internal/config" "github.com/kyleconroy/sqlc/internal/core" "github.com/kyleconroy/sqlc/internal/inflection" + "github.com/kyleconroy/sqlc/internal/metadata" pyast "github.com/kyleconroy/sqlc/internal/python/ast" "github.com/kyleconroy/sqlc/internal/python/poet" pyprint "github.com/kyleconroy/sqlc/internal/python/printer" @@ -390,7 +392,7 @@ func sqlalchemySQL(s string, engine config.Engine) string { return s } -func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query { +func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) { qs := make([]Query, 0, len(r.Queries)) for _, query := range r.Queries { if query.Name == "" { @@ -399,6 +401,9 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs if query.Cmd == "" { continue } + if query.Cmd == metadata.CmdCopyFrom { + return nil, errors.New("Support for CopyFrom in Python is not implemented") + } methodName := MethodName(query.Name) @@ -490,7 +495,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs qs = append(qs, gq) } sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) - return qs + return qs, nil } func importNode(name string) *pyast.Node { @@ -1052,7 +1057,10 @@ func HashComment(s string) string { func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) { enums := buildEnums(r, settings) models := buildModels(r, settings) - queries := buildQueries(r, settings, models) + queries, err := buildQueries(r, settings, models) + if err != nil { + return nil, err + } i := &importer{ Settings: settings, diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index de47b9cd68..6123482881 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -46,6 +46,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if !ok { return nil, errors.New("node is not a statement") } + var table *ast.TableName switch n := raw.Stmt.(type) { case *ast.SelectStmt: case *ast.DeleteStmt: @@ -53,6 +54,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err := validate.InsertStmt(n); err != nil { return nil, err } + var err error + table, err = ParseTableName(n.Relation) + if err != nil { + return nil, err + } case *ast.TruncateStmt: case *ast.UpdateStmt: default: @@ -132,12 +138,13 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } return &Query{ - Cmd: cmd, - Comments: comments, - Name: name, - Params: params, - Columns: cols, - SQL: trimmed, + Cmd: cmd, + Comments: comments, + Name: name, + Params: params, + Columns: cols, + SQL: trimmed, + InsertIntoTable: table, }, nil } diff --git a/internal/compiler/query.go b/internal/compiler/query.go index d2eb1d2fd7..1751f4695a 100644 --- a/internal/compiler/query.go +++ b/internal/compiler/query.go @@ -35,13 +35,16 @@ type Column struct { type Query struct { SQL string Name string - Cmd string // TODO: Pick a better name. One of: one, many, exec, execrows + Cmd string // TODO: Pick a better name. One of: one, many, exec, execrows, copyFrom Columns []*Column Params []Parameter Comments []string // XXX: Hack Filename string + + // Needed for CopyFrom + InsertIntoTable *ast.TableName } type Parameter struct { diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/db.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..062b72ba02 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row + CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/models.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..319a4b26a2 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/models.go @@ -0,0 +1,12 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type MyschemaFoo struct { + A sql.NullString + B sql.NullInt32 +} diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..141f47e6c4 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go @@ -0,0 +1,79 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +// iteratorForInsertSingleValue implements pgx.CopyFromSource. +type iteratorForInsertSingleValue struct { + rows []sql.NullString + skippedFirstNextCall bool +} + +func (r *iteratorForInsertSingleValue) Next() bool { + if len(r.rows) == 0 { + return false + } + if !r.skippedFirstNextCall { + r.skippedFirstNextCall = true + return true + } + r.rows = r.rows[1:] + return len(r.rows) > 0 +} + +func (r iteratorForInsertSingleValue) Values() ([]interface{}, error) { + return []interface{}{ + r.rows[0], + }, nil +} + +func (r iteratorForInsertSingleValue) Err() error { + return nil +} + +func (q *Queries) InsertSingleValue(ctx context.Context, a []sql.NullString) (int64, error) { + return q.db.CopyFrom(ctx, []string{"myschema", "foo"}, []string{"a"}, &iteratorForInsertSingleValue{rows: a}) +} + +type InsertValuesParams struct { + A sql.NullString + B sql.NullInt32 +} + +// iteratorForInsertValues implements pgx.CopyFromSource. +type iteratorForInsertValues struct { + rows []InsertValuesParams + skippedFirstNextCall bool +} + +func (r *iteratorForInsertValues) Next() bool { + if len(r.rows) == 0 { + return false + } + if !r.skippedFirstNextCall { + r.skippedFirstNextCall = true + return true + } + r.rows = r.rows[1:] + return len(r.rows) > 0 +} + +func (r iteratorForInsertValues) Values() ([]interface{}, error) { + return []interface{}{ + r.rows[0].A, + r.rows[0].B, + }, nil +} + +func (r iteratorForInsertValues) Err() error { + return nil +} + +func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) (int64, error) { + return q.db.CopyFrom(ctx, []string{"myschema", "foo"}, []string{"a", "b"}, &iteratorForInsertValues{rows: arg}) +} diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql b/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql new file mode 100644 index 0000000000..4c4e6f5012 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql @@ -0,0 +1,8 @@ +CREATE SCHEMA myschema; +CREATE TABLE myschema.foo (a text, b integer); + +-- name: InsertValues :copyFrom +INSERT INTO myschema.foo (a, b) VALUES ($1, $2); + +-- name: InsertSingleValue :copyFrom +INSERT INTO myschema.foo (a) VALUES ($1); diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/copyfrom/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index f51db0c763..88042e57e4 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -18,6 +18,7 @@ const ( CmdExecRows = ":execrows" CmdMany = ":many" CmdOne = ":one" + CmdCopyFrom = ":copyFrom" ) // A query name must be a valid Go identifier @@ -79,7 +80,7 @@ func Parse(t string, commentStyle CommentSyntax) (string, string, error) { part = part[:len(part)-1] // removes the trailing "*/" element } if len(part) == 2 { - return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execresult']: %s", line) + return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execresult', ':copyFrom']: %s", line) } if len(part) != 4 { return "", "", fmt.Errorf("invalid query comment: %s", line) @@ -87,7 +88,7 @@ func Parse(t string, commentStyle CommentSyntax) (string, string, error) { queryName := part[2] queryType := strings.TrimSpace(part[3]) switch queryType { - case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows: + case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdCopyFrom: default: return "", "", fmt.Errorf("invalid query type: %s", queryType) } diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index 6c9a60b298..7352d5db49 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -1,14 +1,54 @@ package validate import ( + "errors" "fmt" + "github.com/kyleconroy/sqlc/internal/metadata" "github.com/kyleconroy/sqlc/internal/sql/ast" ) +func validateCopyfrom(n ast.Node) error { + stmt, ok := n.(*ast.InsertStmt) + if !ok { + return errors.New(":copyFrom requires an INSERT INTO statement") + } + if stmt.OnConflictClause != nil { + return errors.New(":copyFrom is not compatible with ON CONFLICT") + } + if stmt.WithClause != nil { + return errors.New(":copyFrom is not compatible with WITH clauses") + } + if stmt.ReturningList != nil && len(stmt.ReturningList.Items) > 0 { + return errors.New(":copyFrom is not compatible with RETURNING") + } + sel, ok := stmt.SelectStmt.(*ast.SelectStmt) + if !ok { + return nil + } + if len(sel.FromClause.Items) > 0 { + return errors.New(":copyFrom is not compatible with INSERT INTO ... SELECT") + } + if sel.ValuesLists == nil || len(sel.ValuesLists.Items) != 1 { + return errors.New(":copyFrom requires exactly one example row to be inserted") + } + sublist, ok := sel.ValuesLists.Items[0].(*ast.List) + if !ok { + return nil + } + for _, v := range sublist.Items { + if _, ok := v.(*ast.ParamRef); !ok { + return errors.New(":copyFrom doesn't support non-parameter values") + } + } + return nil +} + func Cmd(n ast.Node, name, cmd string) error { - // TODO: Convert cmd to an enum - if !(cmd == ":many" || cmd == ":one") { + if cmd == metadata.CmdCopyFrom { + return validateCopyfrom(n) + } + if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne) { return nil } var list *ast.List From 93b48ee2f2e85cc7d2f14e75c9f7f98c8d1d89dd Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Thu, 13 Jan 2022 11:59:57 +0100 Subject: [PATCH 4/5] review comments on PR 1352 --- docs/howto/insert.md | 2 +- internal/codegen/golang/gen.go | 12 +++ internal/codegen/golang/query.go | 2 +- .../golang/templates/pgx/copyfromCopy.tmpl | 51 +++++++++++++ .../golang/templates/pgx/interfaceCode.tmpl | 4 +- .../golang/templates/pgx/queryCode.tmpl | 48 +----------- .../codegen/golang/templates/template.tmpl | 21 ++++++ .../copyfrom/postgresql/pgx/go/copyfrom.go | 73 +++++++++++++++++++ .../copyfrom/postgresql/pgx/go/query.sql.go | 65 ----------------- .../copyfrom/postgresql/pgx/query.sql | 4 +- internal/metadata/meta.go | 4 +- internal/sql/validate/cmd.go | 14 ++-- 12 files changed, 173 insertions(+), 127 deletions(-) create mode 100644 internal/codegen/golang/templates/pgx/copyfromCopy.tmpl create mode 100644 internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go diff --git a/docs/howto/insert.md b/docs/howto/insert.md index 6a9224ff87..e78dc7f2ab 100644 --- a/docs/howto/insert.md +++ b/docs/howto/insert.md @@ -148,7 +148,7 @@ CREATE TABLE authors ( bio text NOT NULL ); --- name: CreateAuthors :copyFrom +-- name: CreateAuthors :copyfrom INSERT INTO authors (name, bio) VALUES ($1, $2); ``` diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 890abaa806..7e427d0356 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -3,6 +3,7 @@ package golang import ( "bufio" "bytes" + "errors" "fmt" "go/format" "strings" @@ -98,6 +99,10 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, Structs: structs, } + if tctx.UsesCopyFrom && tctx.SQLPackage != SQLPackagePGX { + return nil, errors.New(":copyfrom is only supported by pgx") + } + output := map[string]string{} execute := func(name, templateName string) error { @@ -138,6 +143,8 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, if golang.OutputQuerierFileName != "" { querierFileName = golang.OutputQuerierFileName } + copyfromFileName := "copyfrom.go" + // TODO(Jille): Make this configurable. if err := execute(dbFileName, "dbFile"); err != nil { return nil, err @@ -150,6 +157,11 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, return nil, err } } + if tctx.UsesCopyFrom { + if err := execute(copyfromFileName, "copyfromFile"); err != nil { + return nil, err + } + } files := map[string]struct{}{} for _, gq := range queries { diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index dfd332f0f7..f4ce648cd3 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -160,7 +160,7 @@ type Query struct { SourceName string Ret QueryValue Arg QueryValue - // Used for :copyFrom + // Used for :copyfrom Table *ast.TableName } diff --git a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl b/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl new file mode 100644 index 0000000000..e637b2356b --- /dev/null +++ b/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl @@ -0,0 +1,51 @@ +{{define "copyfromCodePgx"}} +{{range .GoQueries}} +{{if eq .Cmd ":copyfrom" }} +// iteratorFor{{.MethodName}} implements pgx.CopyFromSource. +type iteratorFor{{.MethodName}} struct { + rows []{{.Arg.DefineType}} + skippedFirstNextCall bool +} + +func (r *iteratorFor{{.MethodName}}) Next() bool { + if len(r.rows) == 0 { + return false + } + if !r.skippedFirstNextCall { + r.skippedFirstNextCall = true + return true + } + r.rows = r.rows[1:] + return len(r.rows) > 0 +} + +func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) { + return []interface{}{ +{{- if .Arg.Struct }} +{{- range .Arg.Struct.Fields }} + r.rows[0].{{.Name}}, +{{- end }} +{{- else }} + r.rows[0], +{{- end }} + }, nil +} + +func (r iteratorFor{{.MethodName}}) Err() error { + return nil +} + +{{range .Comments}}//{{.}} +{{end -}} +{{- if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { + return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) +{{- else}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { + return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) +{{- end}} +} + +{{end}} +{{end}} +{{end}} diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl index 0d4940e74e..510332a0b4 100644 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl @@ -27,9 +27,9 @@ {{- else if eq .Cmd ":execresult" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {{- end}} - {{- if and (eq .Cmd ":copyFrom") ($dbtxParam) }} + {{- if and (eq .Cmd ":copyfrom") ($dbtxParam) }} {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":copyFrom" }} + {{- else if eq .Cmd ":copyfrom" }} {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {{- end}} {{- end}} diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index ebf11505ca..cd7501e824 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -1,7 +1,7 @@ {{define "queryCodePgx"}} {{range .GoQueries}} {{if $.OutputQuery .SourceName}} -{{if ne .Cmd ":copyFrom"}} +{{if ne .Cmd ":copyfrom"}} const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{escape .SQL}} {{$.Q}} @@ -114,52 +114,6 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.Co } {{end}} -{{if eq .Cmd ":copyFrom"}} -// iteratorFor{{.MethodName}} implements pgx.CopyFromSource. -type iteratorFor{{.MethodName}} struct { - rows []{{.Arg.DefineType}} - skippedFirstNextCall bool -} - -func (r *iteratorFor{{.MethodName}}) Next() bool { - if len(r.rows) == 0 { - return false - } - if !r.skippedFirstNextCall { - r.skippedFirstNextCall = true - return true - } - r.rows = r.rows[1:] - return len(r.rows) > 0 -} - -func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) { - return []interface{}{ -{{- if .Arg.Struct }} -{{- range .Arg.Struct.Fields }} - r.rows[0].{{.Name}}, -{{- end }} -{{- else }} - r.rows[0], -{{- end }} - }, nil -} - -func (r iteratorFor{{.MethodName}}) Err() error { - return nil -} - -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { - return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) -{{- else}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { - return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) -{{- end}} -} -{{end}} {{end}} {{end}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index 53a2323b3b..b0d7d3f933 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -116,3 +116,24 @@ import ( {{- template "queryCodeStd" .}} {{end}} {{end}} + +{{define "copyfromFile"}}// Code generated by sqlc. DO NOT EDIT. +// source: {{.SourceName}} + +package {{.Package}} + +import ( + {{range imports .SourceName}} + {{range .}}{{.}} + {{end}} + {{end}} +) + +{{template "copyfromCode" . }} +{{end}} + +{{define "copyfromCode"}} +{{if eq .SQLPackage "pgx/v4"}} + {{- template "copyfromCodePgx" .}} +{{end}} +{{end}} diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go new file mode 100644 index 0000000000..8201aaf968 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go @@ -0,0 +1,73 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: copyfrom.go + +package querytest + +import ( + "context" +) + +// iteratorForInsertSingleValue implements pgx.CopyFromSource. +type iteratorForInsertSingleValue struct { + rows []sql.NullString + skippedFirstNextCall bool +} + +func (r *iteratorForInsertSingleValue) Next() bool { + if len(r.rows) == 0 { + return false + } + if !r.skippedFirstNextCall { + r.skippedFirstNextCall = true + return true + } + r.rows = r.rows[1:] + return len(r.rows) > 0 +} + +func (r iteratorForInsertSingleValue) Values() ([]interface{}, error) { + return []interface{}{ + r.rows[0], + }, nil +} + +func (r iteratorForInsertSingleValue) Err() error { + return nil +} + +func (q *Queries) InsertSingleValue(ctx context.Context, a []sql.NullString) (int64, error) { + return q.db.CopyFrom(ctx, []string{"myschema", "foo"}, []string{"a"}, &iteratorForInsertSingleValue{rows: a}) +} + +// iteratorForInsertValues implements pgx.CopyFromSource. +type iteratorForInsertValues struct { + rows []InsertValuesParams + skippedFirstNextCall bool +} + +func (r *iteratorForInsertValues) Next() bool { + if len(r.rows) == 0 { + return false + } + if !r.skippedFirstNextCall { + r.skippedFirstNextCall = true + return true + } + r.rows = r.rows[1:] + return len(r.rows) > 0 +} + +func (r iteratorForInsertValues) Values() ([]interface{}, error) { + return []interface{}{ + r.rows[0].A, + r.rows[0].B, + }, nil +} + +func (r iteratorForInsertValues) Err() error { + return nil +} + +func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) (int64, error) { + return q.db.CopyFrom(ctx, []string{"myschema", "foo"}, []string{"a", "b"}, &iteratorForInsertValues{rows: arg}) +} diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go index 141f47e6c4..465b92915e 100644 --- a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go @@ -8,72 +8,7 @@ import ( "database/sql" ) -// iteratorForInsertSingleValue implements pgx.CopyFromSource. -type iteratorForInsertSingleValue struct { - rows []sql.NullString - skippedFirstNextCall bool -} - -func (r *iteratorForInsertSingleValue) Next() bool { - if len(r.rows) == 0 { - return false - } - if !r.skippedFirstNextCall { - r.skippedFirstNextCall = true - return true - } - r.rows = r.rows[1:] - return len(r.rows) > 0 -} - -func (r iteratorForInsertSingleValue) Values() ([]interface{}, error) { - return []interface{}{ - r.rows[0], - }, nil -} - -func (r iteratorForInsertSingleValue) Err() error { - return nil -} - -func (q *Queries) InsertSingleValue(ctx context.Context, a []sql.NullString) (int64, error) { - return q.db.CopyFrom(ctx, []string{"myschema", "foo"}, []string{"a"}, &iteratorForInsertSingleValue{rows: a}) -} - type InsertValuesParams struct { A sql.NullString B sql.NullInt32 } - -// iteratorForInsertValues implements pgx.CopyFromSource. -type iteratorForInsertValues struct { - rows []InsertValuesParams - skippedFirstNextCall bool -} - -func (r *iteratorForInsertValues) Next() bool { - if len(r.rows) == 0 { - return false - } - if !r.skippedFirstNextCall { - r.skippedFirstNextCall = true - return true - } - r.rows = r.rows[1:] - return len(r.rows) > 0 -} - -func (r iteratorForInsertValues) Values() ([]interface{}, error) { - return []interface{}{ - r.rows[0].A, - r.rows[0].B, - }, nil -} - -func (r iteratorForInsertValues) Err() error { - return nil -} - -func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) (int64, error) { - return q.db.CopyFrom(ctx, []string{"myschema", "foo"}, []string{"a", "b"}, &iteratorForInsertValues{rows: arg}) -} diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql b/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql index 4c4e6f5012..c622cef8ec 100644 --- a/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/query.sql @@ -1,8 +1,8 @@ CREATE SCHEMA myschema; CREATE TABLE myschema.foo (a text, b integer); --- name: InsertValues :copyFrom +-- name: InsertValues :copyfrom INSERT INTO myschema.foo (a, b) VALUES ($1, $2); --- name: InsertSingleValue :copyFrom +-- name: InsertSingleValue :copyfrom INSERT INTO myschema.foo (a) VALUES ($1); diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 88042e57e4..8fc2f3b32e 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -18,7 +18,7 @@ const ( CmdExecRows = ":execrows" CmdMany = ":many" CmdOne = ":one" - CmdCopyFrom = ":copyFrom" + CmdCopyFrom = ":copyfrom" ) // A query name must be a valid Go identifier @@ -80,7 +80,7 @@ func Parse(t string, commentStyle CommentSyntax) (string, string, error) { part = part[:len(part)-1] // removes the trailing "*/" element } if len(part) == 2 { - return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execresult', ':copyFrom']: %s", line) + return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execresult', ':copyfrom']: %s", line) } if len(part) != 4 { return "", "", fmt.Errorf("invalid query comment: %s", line) diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index 7352d5db49..aed44d4a97 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -11,26 +11,26 @@ import ( func validateCopyfrom(n ast.Node) error { stmt, ok := n.(*ast.InsertStmt) if !ok { - return errors.New(":copyFrom requires an INSERT INTO statement") + return errors.New(":copyfrom requires an INSERT INTO statement") } if stmt.OnConflictClause != nil { - return errors.New(":copyFrom is not compatible with ON CONFLICT") + return errors.New(":copyfrom is not compatible with ON CONFLICT") } if stmt.WithClause != nil { - return errors.New(":copyFrom is not compatible with WITH clauses") + return errors.New(":copyfrom is not compatible with WITH clauses") } if stmt.ReturningList != nil && len(stmt.ReturningList.Items) > 0 { - return errors.New(":copyFrom is not compatible with RETURNING") + return errors.New(":copyfrom is not compatible with RETURNING") } sel, ok := stmt.SelectStmt.(*ast.SelectStmt) if !ok { return nil } if len(sel.FromClause.Items) > 0 { - return errors.New(":copyFrom is not compatible with INSERT INTO ... SELECT") + return errors.New(":copyfrom is not compatible with INSERT INTO ... SELECT") } if sel.ValuesLists == nil || len(sel.ValuesLists.Items) != 1 { - return errors.New(":copyFrom requires exactly one example row to be inserted") + return errors.New(":copyfrom requires exactly one example row to be inserted") } sublist, ok := sel.ValuesLists.Items[0].(*ast.List) if !ok { @@ -38,7 +38,7 @@ func validateCopyfrom(n ast.Node) error { } for _, v := range sublist.Items { if _, ok := v.(*ast.ParamRef); !ok { - return errors.New(":copyFrom doesn't support non-parameter values") + return errors.New(":copyfrom doesn't support non-parameter values") } } return nil From 8597824caa269cbd7b1648eefc454810462feffc Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Thu, 13 Jan 2022 12:29:41 +0100 Subject: [PATCH 5/5] Fix the imports calculation for having copyfrom in a separate file --- internal/codegen/golang/imports.go | 11 ++++++++++- .../testdata/copyfrom/postgresql/pgx/go/copyfrom.go | 1 + .../testdata/copyfrom/postgresql/pgx/go/query.sql.go | 1 - 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 2c38a3c48e..516cbb3018 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -89,6 +89,7 @@ func (i *importer) Imports(filename string) [][]ImportSpec { if i.Settings.Go.OutputQuerierFileName != "" { querierFileName = i.Settings.Go.OutputQuerierFileName } + copyfromFileName := "copyfrom.go" switch filename { case dbFileName: @@ -97,6 +98,8 @@ func (i *importer) Imports(filename string) [][]ImportSpec { return mergeImports(i.modelImports()) case querierFileName: return mergeImports(i.interfaceImports()) + case copyfromFileName: + return mergeImports(i.interfaceImports()) default: return mergeImports(i.queryImports(filename)) } @@ -279,9 +282,13 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp func (i *importer) queryImports(filename string) fileImports { var gq []Query + anyNonCopyFrom := false for _, query := range i.Queries { if query.SourceName == filename { gq = append(gq, query) + if query.Cmd != metadata.CmdCopyFrom { + anyNonCopyFrom = true + } } } @@ -349,7 +356,9 @@ func (i *importer) queryImports(filename string) fileImports { return false } - std["context"] = struct{}{} + if anyNonCopyFrom { + std["context"] = struct{}{} + } sqlpkg := SQLPackageFromString(i.Settings.Go.SQLPackage) if sliceScan() && sqlpkg != SQLPackagePGX { diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go index 8201aaf968..57fdd1f578 100644 --- a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/copyfrom.go @@ -5,6 +5,7 @@ package querytest import ( "context" + "database/sql" ) // iteratorForInsertSingleValue implements pgx.CopyFromSource. diff --git a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go index 465b92915e..7d63553255 100644 --- a/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/copyfrom/postgresql/pgx/go/query.sql.go @@ -4,7 +4,6 @@ package querytest import ( - "context" "database/sql" )