diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 017a326797..cc54036d1d 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -107,6 +107,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err != nil { return nil, err } + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams) if err != nil { return nil, err diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 4551e26425..4074116e3b 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -7,6 +7,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/catalog" + "github.com/kyleconroy/sqlc/internal/sql/named" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) @@ -18,7 +19,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -26,18 +27,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, var defaultTable *ast.TableName var tables []*ast.TableName - parameterName := func(n int, defaultName string) string { - if n, ok := names[n]; ok { - return n - } - return defaultName - } - - isNamedParam := func(n int) bool { - _, ok := names[n] - return ok - } - typeMap := map[string]map[string]map[string]*catalog.Column{} indexTable := func(table catalog.Table) error { tables = append(tables, table.Rel) @@ -92,24 +81,28 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, switch n := ref.parent.(type) { case *limitOffset: + defaultP := named.NewInferredParam("offset", true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, "offset"), + Name: p.Name(), DataType: "integer", - NotNull: true, - IsNamedParam: isNamedParam(ref.ref.Number), + NotNull: p.NotNull(), + IsNamedParam: isNamed, }, }) case *limitCount: + defaultP := named.NewInferredParam("limit", true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, "limit"), + Name: p.Name(), DataType: "integer", - NotNull: true, - IsNamedParam: isNamedParam(ref.ref.Number), + NotNull: p.NotNull(), + IsNamedParam: isNamed, }, }) @@ -127,12 +120,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if astutils.Join(n.Name, ".") == "||" { dataType = "string" } + + defaultP := named.NewParam("") + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, ""), + Name: p.Name(), DataType: dataType, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, + NotNull: p.NotNull(), }, }) continue @@ -185,16 +182,19 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if ref.name != "" { key = ref.name } + + defaultP := named.NewInferredParam(key, c.IsNotNull) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), - NotNull: c.IsNotNull, + NotNull: p.NotNull(), IsArray: c.IsArray, Length: c.Length, Table: table, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } @@ -242,15 +242,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if c, ok := typeMap[schema][table.Name][key]; ok { + defaultP := named.NewInferredParam(key, c.IsNotNull) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), - NotNull: c.IsNotNull, + NotNull: p.NotNull(), IsArray: c.IsArray, Table: table, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } @@ -309,12 +311,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if argName != "" { defaultName = argName } + + defaultP := named.NewInferredParam(defaultName, false) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, defaultName), + Name: p.Name(), DataType: "any", - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, + NotNull: p.NotNull(), }, }) continue @@ -340,13 +346,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, paramName = funcName } + defaultP := named.NewInferredParam(paramName, true) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, paramName), + Name: p.Name(), DataType: dataType(paramType), - NotNull: true, - IsNamedParam: isNamedParam(ref.ref.Number), + NotNull: p.NotNull(), + IsNamedParam: isNamed, }, }) } @@ -399,16 +407,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if c, ok := tableMap[key]; ok { + defaultP := named.NewInferredParam(key, c.IsNotNull) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), - NotNull: c.IsNotNull, + NotNull: p.NotNull(), IsArray: c.IsArray, Table: &ast.TableName{Schema: schema, Name: rel}, Length: c.Length, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } else { @@ -424,7 +434,11 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, return nil, fmt.Errorf("*ast.TypeCast has nil type name") } col := toColumn(n.TypeName) - col.Name = parameterName(ref.ref.Number, col.Name) + defaultP := named.NewInferredParam(col.Name, col.NotNull) + p, _ := params.FetchMerge(ref.ref.Number, defaultP) + + col.Name = p.Name() + col.NotNull = p.NotNull() a = append(a, Parameter{ Number: ref.ref.Number, Column: col, @@ -500,15 +514,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if ref.name != "" { key = ref.name } + defaultP := named.NewInferredParam(key, c.IsNotNull) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), NotNull: c.IsNotNull, IsArray: c.IsArray, Table: table, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/go/db.go b/internal/endtoend/testdata/sqlc_narg/mysql/go/db.go new file mode 100644 index 0000000000..36ef5f4f45 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/go/models.go b/internal/endtoend/testdata/sqlc_narg/mysql/go/models.go new file mode 100644 index 0000000000..faee232b20 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + Bar string + MaybeBar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go new file mode 100644 index 0000000000..db493107fc --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go @@ -0,0 +1,119 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const identOnNonNullable = `-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = ? +` + +func (q *Queries) IdentOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, identOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const identOnNullable = `-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = ? +` + +func (q *Queries) IdentOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, identOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNonNullable = `-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = ? +` + +func (q *Queries) StringOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, stringOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNullable = `-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = ? +` + +func (q *Queries) StringOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, stringOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/query.sql b/internal/endtoend/testdata/sqlc_narg/mysql/query.sql new file mode 100644 index 0000000000..634830cbdf --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (bar text not null, maybe_bar text); + +-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg(bar); + +-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg(maybe_bar); + +-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg('bar'); + +-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg('maybe_bar'); diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json b/internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json new file mode 100644 index 0000000000..0657f4db83 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "mysql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..b0157bd009 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..faee232b20 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + Bar string + MaybeBar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..80509257f8 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go @@ -0,0 +1,107 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const identOnNonNullable = `-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) IdentOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.Query(ctx, identOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const identOnNullable = `-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) IdentOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.Query(ctx, identOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNonNullable = `-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) StringOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.Query(ctx, stringOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNullable = `-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) StringOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.Query(ctx, stringOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql new file mode 100644 index 0000000000..634830cbdf --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (bar text not null, maybe_bar text); + +-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg(bar); + +-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg(maybe_bar); + +-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg('bar'); + +-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg('maybe_bar'); diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..36ef5f4f45 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..faee232b20 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + Bar string + MaybeBar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..2939df932e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,119 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const identOnNonNullable = `-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) IdentOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, identOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const identOnNullable = `-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) IdentOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, identOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNonNullable = `-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) StringOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, stringOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNullable = `-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) StringOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, stringOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..634830cbdf --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (bar text not null, maybe_bar text); + +-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg(bar); + +-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg(maybe_bar); + +-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg('bar'); + +-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg('maybe_bar'); diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..de427d069f --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/source/code.go b/internal/source/code.go index b84324b55f..9a6ed077d3 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -54,25 +54,31 @@ func Mutate(raw string, a []Edit) (string, error) { if len(a) == 0 { return raw, nil } + sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location }) + s := raw - for _, edit := range a { + for idx, edit := range a { start := edit.Location - if start > len(s) { + if start > len(s) || start < 0 { return "", fmt.Errorf("edit start location is out of bounds") } - if len(edit.New) <= 0 { - return "", fmt.Errorf("empty edit contents") - } - if len(edit.Old) <= 0 { - return "", fmt.Errorf("empty edit contents") + + stop := edit.Location + len(edit.Old) + if stop > len(s) { + return "", fmt.Errorf("edit stop location is out of bounds") } - stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty - if stop < len(s) { - s = s[:start] + edit.New + s[stop+1:] - } else { - s = s[:start] + edit.New + + // If this is not the first edit, (applied backwards), check if + // this edit overlaps the previous one (and is therefore a developer error) + if idx != 0 { + prevEdit := a[idx-1] + if prevEdit.Location < edit.Location+len(edit.Old) { + return "", fmt.Errorf("2 edits overlap") + } } + + s = s[:start] + edit.New + s[stop:] } return s, nil } diff --git a/internal/source/mutate_test.go b/internal/source/mutate_test.go new file mode 100644 index 0000000000..dd76888796 --- /dev/null +++ b/internal/source/mutate_test.go @@ -0,0 +1,210 @@ +package source + +import ( + "fmt" + "testing" +) + +// newEdit is a testing helper for quickly generating Edits +func newEdit(loc int, old, new string) Edit { + return Edit{Location: loc, Old: old, New: new} +} + +// TestMutateSingle tests almost every possibility of a single edit +func TestMutateSingle(t *testing.T) { + type test struct { + input string + edit Edit + expected string + } + + tests := []test{ + // Simple edits that replace everything + {"", newEdit(0, "", ""), ""}, + {"a", newEdit(0, "a", "A"), "A"}, + {"abcde", newEdit(0, "abcde", "fghij"), "fghij"}, + {"", newEdit(0, "", "fghij"), "fghij"}, + {"abcde", newEdit(0, "abcde", ""), ""}, + + // Edits that start at the very beginning (But don't cover the whole range) + {"abcde", newEdit(0, "a", "A"), "Abcde"}, + {"abcde", newEdit(0, "ab", "AB"), "ABcde"}, + {"abcde", newEdit(0, "abc", "ABC"), "ABCde"}, + {"abcde", newEdit(0, "abcd", "ABCD"), "ABCDe"}, + + // The above repeated, but with different lengths + {"abcde", newEdit(0, "a", ""), "bcde"}, + {"abcde", newEdit(0, "ab", "A"), "Acde"}, + {"abcde", newEdit(0, "abc", "AB"), "ABde"}, + {"abcde", newEdit(0, "abcd", "AB"), "ABe"}, + + // Edits that touch the end (but don't cover the whole range) + {"abcde", newEdit(4, "e", "E"), "abcdE"}, + {"abcde", newEdit(3, "de", "DE"), "abcDE"}, + {"abcde", newEdit(2, "cde", "CDE"), "abCDE"}, + {"abcde", newEdit(1, "bcde", "BCDE"), "aBCDE"}, + + // The above repeated, but with different lengths + {"abcde", newEdit(4, "e", ""), "abcd"}, + {"abcde", newEdit(3, "de", "D"), "abcD"}, + {"abcde", newEdit(2, "cde", "CD"), "abCD"}, + {"abcde", newEdit(1, "bcde", "BC"), "aBC"}, + + // Raw insertions / deletions + {"abcde", newEdit(0, "", "_"), "_abcde"}, + {"abcde", newEdit(1, "", "_"), "a_bcde"}, + {"abcde", newEdit(2, "", "_"), "ab_cde"}, + {"abcde", newEdit(3, "", "_"), "abc_de"}, + {"abcde", newEdit(4, "", "_"), "abcd_e"}, + {"abcde", newEdit(5, "", "_"), "abcde_"}, + } + + origTests := tests + // Generate the reverse mutations, for every edit - the opposite edit that makes it "undo" + for _, spec := range origTests { + tests = append(tests, test{ + input: spec.expected, + edit: newEdit(spec.edit.Location, spec.edit.New, spec.edit.Old), + expected: spec.input, + }) + } + + for _, spec := range tests { + expected := spec.expected + + actual, err := Mutate(spec.input, []Edit{spec.edit}) + testName := fmt.Sprintf("Mutate(%s, Edit{%v, %v -> %v})", spec.input, spec.edit.Location, spec.edit.Old, spec.edit.New) + if err != nil { + t.Errorf("%s should not error (%v)", testName, err) + continue + } + + if actual != expected { + t.Errorf("%s expected %v; got %v", testName, expected, actual) + } + } +} + +// TestMutateMulti tests combinations of edits +func TestMutateMulti(t *testing.T) { + type test struct { + input string + edit1 Edit + edit2 Edit + expected string + } + + tests := []test{ + // Edits that are >1 character from each other + {"abcde", newEdit(0, "a", "A"), newEdit(2, "c", "C"), "AbCde"}, + {"abcde", newEdit(0, "a", "A"), newEdit(2, "c", "C"), "AbCde"}, + + // 2 edits bump right up next to each other + {"abcde", newEdit(0, "abc", ""), newEdit(3, "de", "DE"), "DE"}, + {"abcde", newEdit(0, "abc", "ABC"), newEdit(3, "de", ""), "ABC"}, + {"abcde", newEdit(0, "abc", "ABC"), newEdit(3, "de", "DE"), "ABCDE"}, + {"abcde", newEdit(1, "b", "BB"), newEdit(2, "c", "CC"), "aBBCCde"}, + + // 2 edits bump next to each other, but don't cover the whole string + {"abcdef", newEdit(1, "bc", "C"), newEdit(3, "de", "D"), "aCDf"}, + {"abcde", newEdit(1, "bc", "CCCC"), newEdit(3, "d", "DDD"), "aCCCCDDDe"}, + + // lengthening edits + {"abcde", newEdit(1, "b", "BBBB"), newEdit(2, "c", "CCCC"), "aBBBBCCCCde"}, + } + + origTests := tests + // Generate the edits in opposite order mutations, source edits should be independent of + // the order the edits are specified + for _, spec := range origTests { + tests = append(tests, test{ + input: spec.input, + edit1: spec.edit2, + edit2: spec.edit1, + expected: spec.expected, + }) + } + + for _, spec := range tests { + expected := spec.expected + + actual, err := Mutate(spec.input, []Edit{spec.edit1, spec.edit2}) + testName := fmt.Sprintf("Mutate(%s, Edits{(%v, %v -> %v), (%v, %v -> %v)})", spec.input, + spec.edit1.Location, spec.edit1.Old, spec.edit1.New, + spec.edit2.Location, spec.edit2.Old, spec.edit2.New) + + if err != nil { + t.Errorf("%s should not error (%v)", testName, err) + continue + } + + if actual != expected { + t.Errorf("%s expected %v; got %v", testName, expected, actual) + } + } +} + +// TestMutateErrorSingle test errors are generated for trivially incorrect single edits +func TestMutateErrorSingle(t *testing.T) { + type test struct { + input string + edit Edit + } + + tests := []test{ + // old text is longer than input text + {"", newEdit(0, "a", "A")}, + {"a", newEdit(0, "aa", "A")}, + {"hello", newEdit(0, "hello!", "A")}, + + // negative indexes + {"aaa", newEdit(-1, "aa", "A")}, + {"aaa", newEdit(-2, "aa", "A")}, + {"aaa", newEdit(-100, "aa", "A")}, + } + + for _, spec := range tests { + edit := spec.edit + + _, err := Mutate(spec.input, []Edit{edit}) + testName := fmt.Sprintf("Mutate(%s, Edit{%v, %v -> %v})", spec.input, edit.Location, edit.Old, edit.New) + if err == nil { + t.Errorf("%s should error (%v)", testName, err) + continue + } + } +} + +// TestMutateErrorMulti tests error that can only happen across multiple errors +func TestMutateErrorMulti(t *testing.T) { + type test struct { + input string + edit1 Edit + edit2 Edit + } + + tests := []test{ + // These edits overlap each other, and are therefore undefined + {"abcdef", newEdit(0, "a", ""), newEdit(0, "a", "A")}, + {"abcdef", newEdit(0, "ab", ""), newEdit(1, "ab", "AB")}, + {"abcdef", newEdit(0, "abc", ""), newEdit(2, "abc", "ABC")}, + + // the last edit is longer than the string itself + {"abcdef", newEdit(0, "abcdefghi", ""), newEdit(2, "abc", "ABC")}, + + // negative indexes + {"abcdef", newEdit(-1, "abc", ""), newEdit(3, "abc", "ABC")}, + {"abcdef", newEdit(0, "abc", ""), newEdit(-1, "abc", "ABC")}, + } + + for _, spec := range tests { + actual, err := Mutate(spec.input, []Edit{spec.edit1, spec.edit2}) + testName := fmt.Sprintf("Mutate(%s, Edits{(%v, %v -> %v), (%v, %v -> %v)})", spec.input, + spec.edit1.Location, spec.edit1.Old, spec.edit1.New, + spec.edit2.Location, spec.edit2.Old, spec.edit2.New) + + if err == nil { + t.Errorf("%s should error, but got (%v)", testName, actual) + } + } +} diff --git a/internal/sql/named/is.go b/internal/sql/named/is.go index 5421a85bb1..ba26c645d2 100644 --- a/internal/sql/named/is.go +++ b/internal/sql/named/is.go @@ -5,15 +5,19 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/astutils" ) +// IsParamFunc fulfills the astutils.Search func IsParamFunc(node ast.Node) bool { call, ok := node.(*ast.FuncCall) if !ok { return false } + if call.Func == nil { return false } - return call.Func.Schema == "sqlc" && call.Func.Name == "arg" + + isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "narg") + return isValid } func IsParamSign(node ast.Node) bool { diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go new file mode 100644 index 0000000000..ec29e6184d --- /dev/null +++ b/internal/sql/named/param.go @@ -0,0 +1,114 @@ +package named + +// nullability represents the nullability of a named parameter. +// The nullability can be: +// 1. unspecified +// 2. inferred +// 3. user-defined +// A user-specified nullability carries a higher precedence than an inferred one +// +// The representation is such that you can bitwise OR together nullability types to +// combine them together. +type nullability int + +const ( + nullUnspecified nullability = 0b0000 + inferredNull nullability = 0b0001 + inferredNotNull nullability = 0b0010 + nullable nullability = 0b0100 + notNullable nullability = 0b1000 +) + +// String implements the Stringer interface +func (n nullability) String() string { + switch n { + case nullUnspecified: + return "NullUnspecified" + case inferredNull: + return "InferredNull" + case inferredNotNull: + return "InferredNotNull" + case nullable: + return "Nullable" + case notNullable: + return "NotNullable" + default: + return "NullInvalid" + } +} + +// Param represents a input argument to the query which can be specified using: +// - positional parameters $1 +// - named parameter operator @param +// - named parameter function calls sqlc.arg(param) +type Param struct { + name string + nullability nullability +} + +// NewParam builds a new params with unspecified nullability +func NewParam(name string) Param { + return Param{name: name, nullability: nullUnspecified} +} + +// NewInferredParam builds a new params with inferred nullability +func NewInferredParam(name string, notNull bool) Param { + if notNull { + return Param{name: name, nullability: inferredNotNull} + } + + return Param{name: name, nullability: inferredNull} +} + +// NewUserNullableParam is a parameter that has been overridden +// by the user to be nullable. +func NewUserNullableParam(name string) Param { + return Param{name: name, nullability: nullable} +} + +// Name is the user defined name to use for this parameter +func (p Param) Name() string { + return p.name +} + +// is checks if this params object has the specified nullability bit set +func (p Param) is(n nullability) bool { + return (p.nullability & n) == n +} + +// NonNull determines whether this param should be "not null" in its current state +func (p Param) NotNull() bool { + const null = false + const notNull = true + + if p.is(notNullable) { + return notNull + } + + if p.is(nullable) { + return null + } + + if p.is(inferredNotNull) { + return notNull + } + + if p.is(inferredNull) { + return null + } + + // This param is unspecified, so by default we choose nullable + // which matches the default behavior of most databases + return null +} + +// mergeParam creates a new param from 2 partially specified params +// If the parameters have different names, the first is preferred +func mergeParam(a, b Param) Param { + name := a.name + if name == "" { + name = b.name + } + + return Param{name: name, nullability: a.nullability | b.nullability} +} diff --git a/internal/sql/named/param_set.go b/internal/sql/named/param_set.go new file mode 100644 index 0000000000..b30de738b3 --- /dev/null +++ b/internal/sql/named/param_set.go @@ -0,0 +1,78 @@ +package named + +// ParamSet represents a set of parameters for a single query +type ParamSet struct { + // does this engine support named parameters? + hasNamedSupport bool + // the set of currently tracked named parameters + namedParams map[string]Param + // the locations of each of the named parameters + namedLocs map[string][]int + // a map of positions currently used + positionToName map[int]string + // argn keeps track of the last checked positional parameter used + argn int +} + +func (p *ParamSet) nextArgNum() int { + for { + if _, ok := p.positionToName[p.argn]; !ok { + return p.argn + } + + p.argn++ + } +} + +// Add adds a parameter to this set and returns the numbered location used for it +func (p *ParamSet) Add(param Param) int { + name := param.name + existing, ok := p.namedParams[name] + + p.namedParams[name] = mergeParam(existing, param) + if ok && p.hasNamedSupport { + return p.namedLocs[name][0] + } + + argn := p.nextArgNum() + p.positionToName[argn] = name + p.namedLocs[name] = append(p.namedLocs[name], argn) + return argn +} + +// FetchMerge fetches an indexed parameter, and merges `mergeP` into it +// Returns: the merged parameter and whether it was a named parameter +func (p *ParamSet) FetchMerge(idx int, mergeP Param) (param Param, isNamed bool) { + name, exists := p.positionToName[idx] + if !exists || name == "" { + return mergeP, false + } + + param, ok := p.namedParams[name] + if !ok { + return mergeP, false + } + + return mergeParam(param, mergeP), true +} + +// NewParamSet creates a set of parameters with the given list of already used positions +func NewParamSet(positionsUsed map[int]bool, hasNamedSupport bool) *ParamSet { + positionToName := make(map[int]string, len(positionsUsed)) + for index, used := range positionsUsed { + if !used { + continue + } + + // assume the previously used params have no name + positionToName[index] = "" + } + + return &ParamSet{ + argn: 1, + namedParams: make(map[string]Param), + namedLocs: make(map[string][]int), + hasNamedSupport: hasNamedSupport, + positionToName: positionToName, + } +} diff --git a/internal/sql/named/param_set_test.go b/internal/sql/named/param_set_test.go new file mode 100644 index 0000000000..99b7ed0575 --- /dev/null +++ b/internal/sql/named/param_set_test.go @@ -0,0 +1,58 @@ +package named + +import "testing" + +func TestParamSet_Add(t *testing.T) { + t.Parallel() + + type test struct { + pset *ParamSet + param Param + expected int + } + + named := NewParamSet(nil, true) + populatedNamed := NewParamSet(map[int]bool{1: true, 2: true, 4: true, 5: true, 6: true}, true) + populatedUnnamed := NewParamSet(map[int]bool{1: true, 2: true, 4: true, 5: true, 6: true}, false) + unnamed := NewParamSet(nil, false) + p1 := NewParam("hello") + p2 := NewParam("world") + + tests := []test{ + // First parameter should be 1 + {named, p1, 1}, + // Duplicate first parameters should be 1 + {named, p1, 1}, + // A new parameter receives a new parameter number + {named, p2, 2}, + // An additional new parameter does _not_ receive a new + {named, p2, 2}, + + // First parameter should be 1 + {unnamed, p1, 1}, + // Duplicate first parameters should increment argn + {unnamed, p1, 2}, + // A new parameter receives a new parameter number + {unnamed, p2, 3}, + // An additional new parameter still does receive a new argn + {unnamed, p2, 4}, + + // First parameter of a pre-populated should be 3 + {populatedNamed, p1, 3}, + {populatedNamed, p1, 3}, + {populatedNamed, p2, 7}, + {populatedNamed, p2, 7}, + + {populatedUnnamed, p1, 3}, + {populatedUnnamed, p1, 7}, + {populatedUnnamed, p2, 8}, + {populatedUnnamed, p2, 9}, + } + + for _, spec := range tests { + actual := spec.pset.Add(spec.param) + if actual != spec.expected { + t.Errorf("ParamSet.Add(%s) expected %v; got %v", spec.param.name, spec.expected, actual) + } + } +} diff --git a/internal/sql/named/param_test.go b/internal/sql/named/param_test.go new file mode 100644 index 0000000000..2643f8b308 --- /dev/null +++ b/internal/sql/named/param_test.go @@ -0,0 +1,82 @@ +package named + +import "testing" + +func TestMergeParamNullability(t *testing.T) { + type test struct { + a Param + b Param + notNull bool + message string + } + + name := "name" + unspec := NewParam(name) + inferredNotNull := NewInferredParam(name, true) + inferredNull := NewInferredParam(name, false) + userDefNull := NewUserNullableParam(name) + + const notNull = true + const null = false + + tests := []test{ + // Unspecified nullability parameter works + {unspec, inferredNotNull, notNull, "Unspec + inferred(not null) = not null"}, + {unspec, inferredNull, null, "Unspec + inferred(not null) = null"}, + {unspec, userDefNull, null, "Unspec + userdef(null) = null"}, + + // Inferred nullability agreeing with user defined nullabilty + {inferredNull, userDefNull, null, "inferred(null) + userdef(null) = null"}, + + // Inferred nullability disagreeing with user defined nullabilty + {inferredNotNull, userDefNull, null, "inferred(not null) + userdef(null) = null"}, + } + + for _, spec := range tests { + a := spec.a + b := spec.b + actual := mergeParam(a, b).NotNull() + expected := spec.notNull + if actual != expected { + t.Errorf("Combine(%s,%s) expected %v; got %v", a.nullability, b.nullability, expected, actual) + } + + // We have already tried Combine(a, b) the same result should be true for Combine(b, a) + actual = mergeParam(b, a).NotNull() + if actual != expected { + t.Errorf("Combine(%s,%s) expected %v; got %v", b.nullability, a.nullability, expected, actual) + } + } +} + +func TestMergeParamName(t *testing.T) { + type test struct { + a Param + b Param + name string + } + + a := NewParam("a") + b := NewParam("b") + blank := NewParam("") + + tests := []test{ + // should prefer the first param's name if both specified + {a, b, "a"}, + {b, a, "b"}, + + // should prefer non-blank names + {a, blank, "a"}, + {blank, a, "a"}, + } + + for _, spec := range tests { + a := spec.a + b := spec.b + actual := mergeParam(a, b).Name() + expected := spec.name + if actual != expected { + t.Errorf("Combine(%s,%s) expected %v; got %v", a.name, b.name, expected, actual) + } + } +} diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index b9ba52001e..250d967e76 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,59 +41,63 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]string, []source.Edit) { +// paramFromFuncCall creates a param from sqlc.n?arg() calls return the +// parameter and whether the parameter name was specified a best guess as its +// "source" string representation (used for replacing this function call in the +// original SQL query) +func paramFromFuncCall(call *ast.FuncCall) (named.Param, string) { + paramName, isConst := flatten(call.Args) + + // origName keeps track of how the parameter was specified in the source SQL + origName := paramName + if isConst { + origName = fmt.Sprintf("'%s'", paramName) + } + + param := named.NewParam(paramName) + if call.Func.Name == "narg" { + param = named.NewUserNullableParam(paramName) + } + + // TODO: This code assumes that sqlc.arg(name) / sqlc.narg(name) is on a single line + // with no extraneous spaces (or any non-significant tokens for that matter) + origText := fmt.Sprintf("%s.%s(%s)", call.Func.Schema, call.Func.Name, origName) + return param, origText +} + +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, *named.ParamSet, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) + hasNamedParameterSupport := engine != config.EngineMySQL + allParams := named.NewParamSet(numbs, hasNamedParameterSupport) + if len(foundFunc.Items)+len(foundSign.Items) == 0 { - return raw, map[int]string{}, nil + return raw, allParams, nil } - hasNamedParameterSupport := engine != config.EngineMySQL - - args := map[string][]int{} - argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() switch { case named.IsParamFunc(node): fun := node.(*ast.FuncCall) - param, isConst := flatten(fun.Args) - if nums, ok := args[param]; ok && hasNamedParameterSupport { - cr.Replace(&ast.ParamRef{ - Number: nums[0], - Location: fun.Location, - }) - } else { - argn++ - for numbs[argn] { - argn++ - } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } - cr.Replace(&ast.ParamRef{ - Number: argn, - Location: fun.Location, - }) - } - // TODO: This code assumes that sqlc.arg(name) is on a single line - var old, replace string - if isConst { - old = fmt.Sprintf("sqlc.arg('%s')", param) - } else { - old = fmt.Sprintf("sqlc.arg(%s)", param) - } + param, origText := paramFromFuncCall(fun) + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: fun.Location, + }) + + var replace string if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param][0]) + replace = fmt.Sprintf("$%d", argn) } + edits = append(edits, source.Edit{ Location: fun.Location - raw.StmtLocation, - Old: old, + Old: origText, New: replace, }) return false @@ -101,76 +105,53 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case isNamedParamSignCast(node): expr := node.(*ast.A_Expr) cast := expr.Rexpr.(*ast.TypeCast) - param, _ := flatten(cast.Arg) - if nums, ok := args[param]; ok { - cast.Arg = &ast.ParamRef{ - Number: nums[0], - Location: expr.Location, - } - cr.Replace(cast) - } else { - argn++ - for numbs[argn] { - argn++ - } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } - cast.Arg = &ast.ParamRef{ - Number: argn, - Location: expr.Location, - } - cr.Replace(cast) + paramName, _ := flatten(cast.Arg) + param := named.NewParam(paramName) + + argn := allParams.Add(param) + cast.Arg = &ast.ParamRef{ + Number: argn, + Location: expr.Location, } + cr.Replace(cast) + // TODO: This code assumes that @foo::bool is on a single line var replace string if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param][0]) + replace = fmt.Sprintf("$%d", argn) } + edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, - Old: fmt.Sprintf("@%s", param), + Old: fmt.Sprintf("@%s", paramName), New: replace, }) return false case named.IsParamSign(node): expr := node.(*ast.A_Expr) - param, _ := flatten(expr.Rexpr) - if nums, ok := args[param]; ok { - cr.Replace(&ast.ParamRef{ - Number: nums[0], - Location: expr.Location, - }) - } else { - argn++ - for numbs[argn] { - argn++ - } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } - cr.Replace(&ast.ParamRef{ - Number: argn, - Location: expr.Location, - }) - } + paramName, _ := flatten(expr.Rexpr) + param := named.NewParam(paramName) + + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: expr.Location, + }) + // TODO: This code assumes that @foo is on a single line var replace string if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param][0]) + replace = fmt.Sprintf("$%d", argn) } + edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, - Old: fmt.Sprintf("@%s", param), + Old: fmt.Sprintf("@%s", paramName), New: replace, }) return false @@ -180,11 +161,5 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, } }, nil) - named := map[int]string{} - for k, vs := range args { - for _, v := range vs { - named[v] = k - } - } - return node.(*ast.RawStmt), named, edits + return node.(*ast.RawStmt), allParams, edits } diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 85c3df0d7e..5fbac048d2 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -34,7 +34,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { // Custom validation for sqlc.arg // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if fn.Name != "arg" { + if !(fn.Name == "arg" || fn.Name == "narg") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index fbec8f9066..170a158527 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -3,6 +3,7 @@ package validate import ( "errors" "fmt" + "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/sqlerr"