From 8ee0ea8d23183695f38546382109db8f88b2aebd Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Thu, 7 Apr 2022 00:47:44 -0700 Subject: [PATCH 01/11] refactor named parameters to use specific type --- internal/compiler/parse.go | 1 + internal/compiler/resolve.go | 82 +++++++++++++++++++----------- internal/sql/named/param.go | 79 ++++++++++++++++++++++++++++ internal/sql/rewrite/parameters.go | 10 ++-- 4 files changed, 137 insertions(+), 35 deletions(-) create mode 100644 internal/sql/named/param.go diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 017a326797..cc54036d1d 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -107,6 +107,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err != nil { return nil, err } + params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams) if err != nil { return nil, err diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 4551e26425..173a84aa80 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -7,6 +7,7 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/catalog" + "github.com/kyleconroy/sqlc/internal/sql/named" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" ) @@ -18,7 +19,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params map[int]named.Param) ([]Parameter, error) { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -26,16 +27,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, var defaultTable *ast.TableName var tables []*ast.TableName - parameterName := func(n int, defaultName string) string { - if n, ok := names[n]; ok { - return n - } - return defaultName - } - - isNamedParam := func(n int) bool { - _, ok := names[n] - return ok + // fetch param fetches the named parameter at index `n` or a default substitution + // and returns whether it was found or not. + fetchParam := func(n int, defaultP named.Param) (named.Param, bool) { + p, ok := params[n] + p = named.Combine(p, defaultP) + return p, ok } typeMap := map[string]map[string]map[string]*catalog.Column{} @@ -92,24 +89,28 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, switch n := ref.parent.(type) { case *limitOffset: + defaultP := named.NewParam("offset", true) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, "offset"), + Name: p.Name(), DataType: "integer", NotNull: true, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) case *limitCount: + defaultP := named.NewParam("limit", true) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, "limit"), + Name: p.Name(), DataType: "integer", NotNull: true, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) @@ -127,12 +128,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if astutils.Join(n.Name, ".") == "||" { dataType = "string" } + + defaultP := named.NewUnspecifiedParam("") + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, ""), + Name: p.Name(), DataType: dataType, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) continue @@ -185,16 +189,19 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if ref.name != "" { key = ref.name } + + defaultP := named.NewParam(key, c.IsNotNull) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), NotNull: c.IsNotNull, IsArray: c.IsArray, Length: c.Length, Table: table, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } @@ -242,15 +249,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if c, ok := typeMap[schema][table.Name][key]; ok { + defaultP := named.NewParam(key, c.IsNotNull) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), NotNull: c.IsNotNull, IsArray: c.IsArray, Table: table, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } @@ -309,12 +318,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if argName != "" { defaultName = argName } + + defaultP := named.NewParam(defaultName, false) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, defaultName), + Name: p.Name(), DataType: "any", - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) continue @@ -340,13 +352,15 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, paramName = funcName } + defaultP := named.NewParam(paramName, true) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, paramName), + Name: p.Name(), DataType: dataType(paramType), NotNull: true, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } @@ -399,16 +413,18 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if c, ok := tableMap[key]; ok { + defaultP := named.NewParam(key, c.IsNotNull) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), NotNull: c.IsNotNull, IsArray: c.IsArray, Table: &ast.TableName{Schema: schema, Name: rel}, Length: c.Length, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } else { @@ -424,7 +440,11 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, return nil, fmt.Errorf("*ast.TypeCast has nil type name") } col := toColumn(n.TypeName) - col.Name = parameterName(ref.ref.Number, col.Name) + defaultP := named.NewParam(col.Name, col.NotNull) + p, _ := fetchParam(ref.ref.Number, defaultP) + + col.Name = p.Name() + col.NotNull = p.NotNull() a = append(a, Parameter{ Number: ref.ref.Number, Column: col, @@ -500,15 +520,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if ref.name != "" { key = ref.name } + defaultP := named.NewParam(key, c.IsNotNull) + p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ - Name: parameterName(ref.ref.Number, key), + Name: p.Name(), DataType: dataType(&c.Type), NotNull: c.IsNotNull, IsArray: c.IsArray, Table: table, - IsNamedParam: isNamedParam(ref.ref.Number), + IsNamedParam: isNamed, }, }) } diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go new file mode 100644 index 0000000000..4afff50c7f --- /dev/null +++ b/internal/sql/named/param.go @@ -0,0 +1,79 @@ +package named + +// Nullability represents the nullability of a named parameter The +// representation is such that you can bitwise OR together Nullability types to +// combine them +// For example: +// - NullUnspecified | Nullable = Nullable +// - NonNullable | Nullable = NullInvalid +type Nullability int + +const ( + NullUnspecified Nullability = 0b00 + Nullable Nullability = 0b01 + NotNullable Nullability = 0b10 + NullInvalid Nullability = 0b11 +) + +// String implements the Stringer interface +func (n Nullability) String() string { + switch n { + case NullUnspecified: + return "NullUnspecified" + case Nullable: + return "Nullable" + case NotNullable: + return "NotNullable" + default: + return "NullInvalid" + } +} + +// Param represents a input argument to the query which can be specified using: +// - positional parameters $1 +// - named parameter operator @param +// - named parameter function calls sqlc.arg(param) +type Param struct { + name string + nullability Nullability +} + +// NewUnspecifiedParam builds a new params with unspecified nullability +func NewUnspecifiedParam(name string) Param { + return Param{name: name, nullability: NullUnspecified} +} + +// NewParam creates a new named param with the given nullability +func NewParam(name string, notNull bool) Param { + if notNull { + return Param{name: name, nullability: NotNullable} + } + + return Param{name: name, nullability: Nullable} +} + +// Name is the user defined name to use for this parameter +func (p Param) Name() string { + return p.name +} + +// Nullability retrieves the nullability status of this param +func (p Param) Nullability() Nullability { + return p.nullability +} + +// NonNull determines whether this param is NonNull +func (p Param) NotNull() bool { + return (p.nullability & NotNullable) > 0 +} + +// Combine creates a new param from 2 partially specified params +// If the parameters have different names, the first is preferred +func Combine(a, b Param) Param { + name := a.name + if name == "" { + name = b.name + } + + return Param{name: name, nullability: a.nullability | b.nullability} +} diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index b9ba52001e..2d645d299b 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,11 +41,11 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]string, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]named.Param, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) if len(foundFunc.Items)+len(foundSign.Items) == 0 { - return raw, map[int]string{}, nil + return raw, map[int]named.Param{}, nil } hasNamedParameterSupport := engine != config.EngineMySQL @@ -180,11 +180,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, } }, nil) - named := map[int]string{} + paramByLoc := map[int]named.Param{} for k, vs := range args { for _, v := range vs { - named[v] = k + paramByLoc[v] = named.NewUnspecifiedParam(k) } } - return node.(*ast.RawStmt), named, edits + return node.(*ast.RawStmt), paramByLoc, edits } From 48598af749d6fe9276a1f7ac41eb1bbf5506741a Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Thu, 7 Apr 2022 01:30:56 -0700 Subject: [PATCH 02/11] refactor rewrite.NamedParameters to use types --- internal/sql/rewrite/parameters.go | 82 +++++++++++++++++------------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 2d645d299b..824ef7ee2d 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,6 +41,17 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } +type namedParameter struct { + param named.Param + locs []int +} + +// Add a new instance of this parameter of the same name +func (n *namedParameter) AddInstance(loc int, p named.Param) { + n.param = named.Combine(n.param, p) + n.locs = append(n.locs, loc) +} + func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]named.Param, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) @@ -50,7 +61,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, hasNamedParameterSupport := engine != config.EngineMySQL - args := map[string][]int{} + args := map[string]namedParameter{} argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { @@ -58,22 +69,23 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, switch { case named.IsParamFunc(node): fun := node.(*ast.FuncCall) - param, isConst := flatten(fun.Args) - if nums, ok := args[param]; ok && hasNamedParameterSupport { + paramName, isConst := flatten(fun.Args) + if namedP, ok := args[paramName]; ok && hasNamedParameterSupport { cr.Replace(&ast.ParamRef{ - Number: nums[0], + Number: namedP.locs[0], Location: fun.Location, }) } else { + // Find the arg number that has not yet been used argn++ for numbs[argn] { argn++ } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } + + // keep track of the locations this argument is present + p := args[paramName] + p.AddInstance(argn, named.NewUnspecifiedParam(paramName)) + args[paramName] = p cr.Replace(&ast.ParamRef{ Number: argn, Location: fun.Location, @@ -82,14 +94,14 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, // TODO: This code assumes that sqlc.arg(name) is on a single line var old, replace string if isConst { - old = fmt.Sprintf("sqlc.arg('%s')", param) + old = fmt.Sprintf("sqlc.arg('%s')", paramName) } else { - old = fmt.Sprintf("sqlc.arg(%s)", param) + old = fmt.Sprintf("sqlc.arg(%s)", paramName) } if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param][0]) + replace = fmt.Sprintf("$%d", args[paramName].locs[0]) } edits = append(edits, source.Edit{ Location: fun.Location - raw.StmtLocation, @@ -101,10 +113,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case isNamedParamSignCast(node): expr := node.(*ast.A_Expr) cast := expr.Rexpr.(*ast.TypeCast) - param, _ := flatten(cast.Arg) - if nums, ok := args[param]; ok { + paramName, _ := flatten(cast.Arg) + if p, ok := args[paramName]; ok { cast.Arg = &ast.ParamRef{ - Number: nums[0], + Number: p.locs[0], Location: expr.Location, } cr.Replace(cast) @@ -113,11 +125,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, for numbs[argn] { argn++ } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } + + p := args[paramName] + p.AddInstance(argn, named.NewUnspecifiedParam(paramName)) + args[paramName] = p + cast.Arg = &ast.ParamRef{ Number: argn, Location: expr.Location, @@ -129,21 +141,21 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param][0]) + replace = fmt.Sprintf("$%d", args[paramName].locs[0]) } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, - Old: fmt.Sprintf("@%s", param), + Old: fmt.Sprintf("@%s", paramName), New: replace, }) return false case named.IsParamSign(node): expr := node.(*ast.A_Expr) - param, _ := flatten(expr.Rexpr) - if nums, ok := args[param]; ok { + paramName, _ := flatten(expr.Rexpr) + if p, ok := args[paramName]; ok { cr.Replace(&ast.ParamRef{ - Number: nums[0], + Number: p.locs[0], Location: expr.Location, }) } else { @@ -151,11 +163,10 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, for numbs[argn] { argn++ } - if _, found := args[param]; !found { - args[param] = []int{argn} - } else { - args[param] = append(args[param], argn) - } + + p := args[paramName] + p.AddInstance(argn, named.NewUnspecifiedParam(paramName)) + args[paramName] = p cr.Replace(&ast.ParamRef{ Number: argn, Location: expr.Location, @@ -166,11 +177,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param][0]) + replace = fmt.Sprintf("$%d", args[paramName].locs[0]) } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, - Old: fmt.Sprintf("@%s", param), + Old: fmt.Sprintf("@%s", paramName), New: replace, }) return false @@ -181,10 +192,11 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, }, nil) paramByLoc := map[int]named.Param{} - for k, vs := range args { - for _, v := range vs { - paramByLoc[v] = named.NewUnspecifiedParam(k) + for _, namedParam := range args { + for _, loc := range namedParam.locs { + paramByLoc[loc] = namedParam.param } } + return node.(*ast.RawStmt), paramByLoc, edits } From 5095f2e9d72566ac970f7363e1ef5b6f08d1a7fe Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Fri, 8 Apr 2022 18:22:51 -0700 Subject: [PATCH 03/11] add "inferred null" vs "user defined null" --- internal/compiler/resolve.go | 32 +++++++------ internal/sql/named/param.go | 73 +++++++++++++++++++++++------- internal/sql/rewrite/parameters.go | 46 +++++++++++-------- 3 files changed, 99 insertions(+), 52 deletions(-) diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 173a84aa80..d9225ba7a5 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -89,27 +89,27 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, switch n := ref.parent.(type) { case *limitOffset: - defaultP := named.NewParam("offset", true) + defaultP := named.NewInferredParam("offset", true) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), DataType: "integer", - NotNull: true, + NotNull: p.NotNull(), IsNamedParam: isNamed, }, }) case *limitCount: - defaultP := named.NewParam("limit", true) + defaultP := named.NewInferredParam("limit", true) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), DataType: "integer", - NotNull: true, + NotNull: p.NotNull(), IsNamedParam: isNamed, }, }) @@ -137,6 +137,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, Name: p.Name(), DataType: dataType, IsNamedParam: isNamed, + NotNull: p.NotNull(), }, }) continue @@ -190,14 +191,14 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, key = ref.name } - defaultP := named.NewParam(key, c.IsNotNull) + defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), DataType: dataType(&c.Type), - NotNull: c.IsNotNull, + NotNull: p.NotNull(), IsArray: c.IsArray, Length: c.Length, Table: table, @@ -249,14 +250,14 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if c, ok := typeMap[schema][table.Name][key]; ok { - defaultP := named.NewParam(key, c.IsNotNull) + defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ Name: p.Name(), DataType: dataType(&c.Type), - NotNull: c.IsNotNull, + NotNull: p.NotNull(), IsArray: c.IsArray, Table: table, IsNamedParam: isNamed, @@ -319,7 +320,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, defaultName = argName } - defaultP := named.NewParam(defaultName, false) + defaultP := named.NewInferredParam(defaultName, false) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, @@ -327,6 +328,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, Name: p.Name(), DataType: "any", IsNamedParam: isNamed, + NotNull: p.NotNull(), }, }) continue @@ -352,14 +354,14 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, paramName = funcName } - defaultP := named.NewParam(paramName, true) + defaultP := named.NewInferredParam(paramName, true) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), DataType: dataType(paramType), - NotNull: true, + NotNull: p.NotNull(), IsNamedParam: isNamed, }, }) @@ -413,14 +415,14 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } if c, ok := tableMap[key]; ok { - defaultP := named.NewParam(key, c.IsNotNull) + defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ Name: p.Name(), DataType: dataType(&c.Type), - NotNull: c.IsNotNull, + NotNull: p.NotNull(), IsArray: c.IsArray, Table: &ast.TableName{Schema: schema, Name: rel}, Length: c.Length, @@ -440,7 +442,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, return nil, fmt.Errorf("*ast.TypeCast has nil type name") } col := toColumn(n.TypeName) - defaultP := named.NewParam(col.Name, col.NotNull) + defaultP := named.NewInferredParam(col.Name, col.NotNull) p, _ := fetchParam(ref.ref.Number, defaultP) col.Name = p.Name() @@ -520,7 +522,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if ref.name != "" { key = ref.name } - defaultP := named.NewParam(key, c.IsNotNull) + defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := fetchParam(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go index 4afff50c7f..d6c5e0d355 100644 --- a/internal/sql/named/param.go +++ b/internal/sql/named/param.go @@ -1,18 +1,22 @@ package named -// Nullability represents the nullability of a named parameter The -// representation is such that you can bitwise OR together Nullability types to -// combine them -// For example: -// - NullUnspecified | Nullable = Nullable -// - NonNullable | Nullable = NullInvalid +// Nullability represents the nullability of a named parameter. +// The nullability can be: +// 1. unspecified +// 2. inferred +// 3. user-defined +// A user-specified nullability carries a higher precedence than an inferred one +// +// The representation is such that you can bitwise OR together Nullability types to +// combine them together. type Nullability int const ( - NullUnspecified Nullability = 0b00 - Nullable Nullability = 0b01 - NotNullable Nullability = 0b10 - NullInvalid Nullability = 0b11 + NullUnspecified Nullability = 0b0000 + InferredNull Nullability = 0b0001 + InferredNotNull Nullability = 0b0010 + Nullable Nullability = 0b0100 + NotNullable Nullability = 0b1000 ) // String implements the Stringer interface @@ -20,6 +24,10 @@ func (n Nullability) String() string { switch n { case NullUnspecified: return "NullUnspecified" + case InferredNull: + return "InferredNull" + case InferredNotNull: + return "InferredNotNull" case Nullable: return "Nullable" case NotNullable: @@ -43,8 +51,18 @@ func NewUnspecifiedParam(name string) Param { return Param{name: name, nullability: NullUnspecified} } -// NewParam creates a new named param with the given nullability -func NewParam(name string, notNull bool) Param { +// NewInferredParam builds a new params with inferred nullability +func NewInferredParam(name string, notNull bool) Param { + if notNull { + return Param{name: name, nullability: InferredNotNull} + } + + return Param{name: name, nullability: InferredNull} +} + +// NewUserDefinedParam creates a new param with the user specified +// by the end user +func NewUserDefinedParam(name string, notNull bool) Param { if notNull { return Param{name: name, nullability: NotNullable} } @@ -57,14 +75,35 @@ func (p Param) Name() string { return p.name } -// Nullability retrieves the nullability status of this param -func (p Param) Nullability() Nullability { - return p.nullability +// is checks if this params object has the specified nullability bit set +func (p Param) is(n Nullability) bool { + return (p.nullability & n) == n } -// NonNull determines whether this param is NonNull +// NonNull determines whether this param should be "not null" in its current state func (p Param) NotNull() bool { - return (p.nullability & NotNullable) > 0 + const nullable = false + const notNull = true + + if p.is(NotNullable) { + return notNull + } + + if p.is(Nullable) { + return nullable + } + + if p.is(InferredNotNull) { + return notNull + } + + if p.is(InferredNull) { + return nullable + } + + // This param is unspecified, so by default we choose nullable + // which matches the default behavior of most databases + return nullable } // Combine creates a new param from 2 partially specified params diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 824ef7ee2d..aa4f2a7797 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -47,9 +47,18 @@ type namedParameter struct { } // Add a new instance of this parameter of the same name -func (n *namedParameter) AddInstance(loc int, p named.Param) { - n.param = named.Combine(n.param, p) - n.locs = append(n.locs, loc) +func (n namedParameter) AddInstance(loc int, p named.Param) namedParameter { + param := named.Combine(n.param, p) + locs := append(n.locs, loc) + return namedParameter{ + param: param, + locs: locs, + } +} + +// paramFromName takes a user-defined parameter name and builds the appropiate parameter +func paramFromName(name string) named.Param { + return named.NewUnspecifiedParam(name) } func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]named.Param, []source.Edit) { @@ -70,7 +79,9 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamFunc(node): fun := node.(*ast.FuncCall) paramName, isConst := flatten(fun.Args) - if namedP, ok := args[paramName]; ok && hasNamedParameterSupport { + param := paramFromName(paramName) + + if namedP, ok := args[param.Name()]; ok && hasNamedParameterSupport { cr.Replace(&ast.ParamRef{ Number: namedP.locs[0], Location: fun.Location, @@ -82,10 +93,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, argn++ } - // keep track of the locations this argument is present - p := args[paramName] - p.AddInstance(argn, named.NewUnspecifiedParam(paramName)) - args[paramName] = p + args[param.Name()] = args[param.Name()].AddInstance(argn, param) cr.Replace(&ast.ParamRef{ Number: argn, Location: fun.Location, @@ -101,7 +109,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[paramName].locs[0]) + replace = fmt.Sprintf("$%d", args[param.Name()].locs[0]) } edits = append(edits, source.Edit{ Location: fun.Location - raw.StmtLocation, @@ -114,7 +122,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, expr := node.(*ast.A_Expr) cast := expr.Rexpr.(*ast.TypeCast) paramName, _ := flatten(cast.Arg) - if p, ok := args[paramName]; ok { + param := paramFromName(paramName) + if p, ok := args[param.Name()]; ok { cast.Arg = &ast.ParamRef{ Number: p.locs[0], Location: expr.Location, @@ -126,10 +135,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, argn++ } - p := args[paramName] - p.AddInstance(argn, named.NewUnspecifiedParam(paramName)) - args[paramName] = p - + args[param.Name()] = args[param.Name()].AddInstance(argn, param) cast.Arg = &ast.ParamRef{ Number: argn, Location: expr.Location, @@ -141,7 +147,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[paramName].locs[0]) + replace = fmt.Sprintf("$%d", args[param.Name()].locs[0]) } edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, @@ -153,7 +159,8 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamSign(node): expr := node.(*ast.A_Expr) paramName, _ := flatten(expr.Rexpr) - if p, ok := args[paramName]; ok { + param := paramFromName(paramName) + if p, ok := args[param.Name()]; ok { cr.Replace(&ast.ParamRef{ Number: p.locs[0], Location: expr.Location, @@ -164,9 +171,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, argn++ } - p := args[paramName] - p.AddInstance(argn, named.NewUnspecifiedParam(paramName)) - args[paramName] = p + args[param.Name()] = args[param.Name()].AddInstance(argn, param) cr.Replace(&ast.ParamRef{ Number: argn, Location: expr.Location, @@ -177,8 +182,9 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[paramName].locs[0]) + replace = fmt.Sprintf("$%d", args[param.Name()].locs[0]) } + edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, Old: fmt.Sprintf("@%s", paramName), From 0d480bcad6c2ad792e01139249b0854b05dc5f64 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Fri, 8 Apr 2022 22:54:27 -0700 Subject: [PATCH 04/11] test: add tests for named.Param --- internal/sql/named/param.go | 58 ++++++++++----------- internal/sql/named/param_test.go | 86 ++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 29 deletions(-) create mode 100644 internal/sql/named/param_test.go diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go index d6c5e0d355..fcebd65814 100644 --- a/internal/sql/named/param.go +++ b/internal/sql/named/param.go @@ -1,36 +1,36 @@ package named -// Nullability represents the nullability of a named parameter. +// nullability represents the nullability of a named parameter. // The nullability can be: // 1. unspecified // 2. inferred // 3. user-defined // A user-specified nullability carries a higher precedence than an inferred one // -// The representation is such that you can bitwise OR together Nullability types to +// The representation is such that you can bitwise OR together nullability types to // combine them together. -type Nullability int +type nullability int const ( - NullUnspecified Nullability = 0b0000 - InferredNull Nullability = 0b0001 - InferredNotNull Nullability = 0b0010 - Nullable Nullability = 0b0100 - NotNullable Nullability = 0b1000 + nullUnspecified nullability = 0b0000 + inferredNull nullability = 0b0001 + inferredNotNull nullability = 0b0010 + nullable nullability = 0b0100 + notNullable nullability = 0b1000 ) // String implements the Stringer interface -func (n Nullability) String() string { +func (n nullability) String() string { switch n { - case NullUnspecified: + case nullUnspecified: return "NullUnspecified" - case InferredNull: + case inferredNull: return "InferredNull" - case InferredNotNull: + case inferredNotNull: return "InferredNotNull" - case Nullable: + case nullable: return "Nullable" - case NotNullable: + case notNullable: return "NotNullable" default: return "NullInvalid" @@ -43,31 +43,31 @@ func (n Nullability) String() string { // - named parameter function calls sqlc.arg(param) type Param struct { name string - nullability Nullability + nullability nullability } // NewUnspecifiedParam builds a new params with unspecified nullability func NewUnspecifiedParam(name string) Param { - return Param{name: name, nullability: NullUnspecified} + return Param{name: name, nullability: nullUnspecified} } // NewInferredParam builds a new params with inferred nullability func NewInferredParam(name string, notNull bool) Param { if notNull { - return Param{name: name, nullability: InferredNotNull} + return Param{name: name, nullability: inferredNotNull} } - return Param{name: name, nullability: InferredNull} + return Param{name: name, nullability: inferredNull} } // NewUserDefinedParam creates a new param with the user specified // by the end user func NewUserDefinedParam(name string, notNull bool) Param { if notNull { - return Param{name: name, nullability: NotNullable} + return Param{name: name, nullability: notNullable} } - return Param{name: name, nullability: Nullable} + return Param{name: name, nullability: nullable} } // Name is the user defined name to use for this parameter @@ -76,34 +76,34 @@ func (p Param) Name() string { } // is checks if this params object has the specified nullability bit set -func (p Param) is(n Nullability) bool { +func (p Param) is(n nullability) bool { return (p.nullability & n) == n } // NonNull determines whether this param should be "not null" in its current state func (p Param) NotNull() bool { - const nullable = false + const null = false const notNull = true - if p.is(NotNullable) { + if p.is(notNullable) { return notNull } - if p.is(Nullable) { - return nullable + if p.is(nullable) { + return null } - if p.is(InferredNotNull) { + if p.is(inferredNotNull) { return notNull } - if p.is(InferredNull) { - return nullable + if p.is(inferredNull) { + return null } // This param is unspecified, so by default we choose nullable // which matches the default behavior of most databases - return nullable + return null } // Combine creates a new param from 2 partially specified params diff --git a/internal/sql/named/param_test.go b/internal/sql/named/param_test.go new file mode 100644 index 0000000000..49748f6d70 --- /dev/null +++ b/internal/sql/named/param_test.go @@ -0,0 +1,86 @@ +package named + +import "testing" + +func TestCombineNullability(t *testing.T) { + type test struct { + a Param + b Param + notNull bool + message string + } + + name := "name" + unspec := NewUnspecifiedParam(name) + inferredNotNull := NewInferredParam(name, true) + inferredNull := NewInferredParam(name, false) + userDefNotNull := NewUserDefinedParam(name, true) + userDefNull := NewUserDefinedParam(name, false) + + const notNull = true + const null = false + + tests := []test{ + // Unspecified nullability parameter works + {unspec, inferredNotNull, notNull, "Unspec + inferred(not null) = not null"}, + {unspec, inferredNull, null, "Unspec + inferred(not null) = null"}, + {unspec, userDefNotNull, notNull, "Unspec + userdef(not null) = not null"}, + {unspec, userDefNull, null, "Unspec + userdef(null) = null"}, + + // Inferred nullability agreeing with user defined nullabilty + {inferredNotNull, userDefNotNull, notNull, "inferred(not null) + userdef(not null) = not null"}, + {inferredNull, userDefNull, null, "inferred(null) + userdef(null) = null"}, + + // Inferred nullability disagreeing with user defined nullabilty + {inferredNotNull, userDefNull, null, "inferred(not null) + userdef(null) = null"}, + {inferredNull, userDefNotNull, notNull, "inferred(null) + userdef(not null) = not null"}, + } + + for _, spec := range tests { + a := spec.a + b := spec.b + actual := Combine(a, b).NotNull() + expected := spec.notNull + if actual != expected { + t.Errorf("Combine(%s,%s) expected %v; got %v", a.nullability, b.nullability, expected, actual) + } + + // We have already tried Combine(a, b) the same result should be true for Combine(b, a) + actual = Combine(b, a).NotNull() + if actual != expected { + t.Errorf("Combine(%s,%s) expected %v; got %v", b.nullability, a.nullability, expected, actual) + } + } +} + +func TestCombineName(t *testing.T) { + type test struct { + a Param + b Param + name string + } + + a := NewUnspecifiedParam("a") + b := NewUnspecifiedParam("b") + blank := NewUnspecifiedParam("") + + tests := []test{ + // should prefer the first param's name if both specified + {a, b, "a"}, + {b, a, "b"}, + + // should prefer non-blank names + {a, blank, "a"}, + {blank, a, "a"}, + } + + for _, spec := range tests { + a := spec.a + b := spec.b + actual := Combine(a, b).Name() + expected := spec.name + if actual != expected { + t.Errorf("Combine(%s,%s) expected %v; got %v", a.name, b.name, expected, actual) + } + } +} From 20be86b55944f928883bf97d7e983c6056911610 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Fri, 8 Apr 2022 18:49:35 -0700 Subject: [PATCH 05/11] implement nullability overrides with `!` and `?` suffixes --- internal/sql/rewrite/parameters.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index aa4f2a7797..064fa08ffc 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -56,8 +56,22 @@ func (n namedParameter) AddInstance(loc int, p named.Param) namedParameter { } } -// paramFromName takes a user-defined parameter name and builds the appropiate parameter +// paramFromName takes a user-defined parameter name, with an optional suffix of +// ? (nullable), or ! (non-null) and builds the appropiate parameter func paramFromName(name string) named.Param { + if len(name) == 0 { + return named.NewUnspecifiedParam(name) + } + + last := name[len(name)-1] + if last == '!' { + return named.NewUserDefinedParam(name[:len(name)-1], true) + } + + if last == '?' { + return named.NewUserDefinedParam(name[:len(name)-1], false) + } + return named.NewUnspecifiedParam(name) } From 2c5239e579cc807cefb00289f9dc668826aa8d28 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Fri, 8 Apr 2022 22:04:23 -0700 Subject: [PATCH 06/11] test: add tests for optional and required parameters --- .../testdata/sqlc_arg/mysql/go/models.go | 7 +++++-- .../testdata/sqlc_arg/mysql/go/query.sql.go | 19 +++++++++++++++++++ .../testdata/sqlc_arg/mysql/query.sql | 8 +++++++- .../sqlc_arg/postgresql/pgx/go/models.go | 7 +++++-- .../sqlc_arg/postgresql/pgx/go/query.sql.go | 19 +++++++++++++++++++ .../sqlc_arg/postgresql/pgx/query.sql | 8 +++++++- .../sqlc_arg/postgresql/stdlib/go/models.go | 7 +++++-- .../postgresql/stdlib/go/query.sql.go | 19 +++++++++++++++++++ .../sqlc_arg/postgresql/stdlib/query.sql | 8 +++++++- 9 files changed, 93 insertions(+), 9 deletions(-) diff --git a/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go b/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go index 82d2d6ffa5..f51ae42348 100644 --- a/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go +++ b/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go @@ -4,8 +4,11 @@ package querytest -import () +import ( + "database/sql" +) type Foo struct { - Name string + Name string + Description sql.NullString } diff --git a/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go index 7b6cbe5f67..a9899dcc4e 100644 --- a/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go @@ -7,6 +7,7 @@ package querytest import ( "context" + "database/sql" ) const funcParamIdent = `-- name: FuncParamIdent :many @@ -62,3 +63,21 @@ func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, e } return items, nil } + +const funcParamStringOptional = `-- name: FuncParamStringOptional :exec +UPDATE foo SET name = coalesce(?, name) +` + +func (q *Queries) FuncParamStringOptional(ctx context.Context, slug sql.NullString) error { + _, err := q.db.ExecContext(ctx, funcParamStringOptional, slug) + return err +} + +const funcParamStringRequired = `-- name: FuncParamStringRequired :exec +UPDATE foo SET description = ? +` + +func (q *Queries) FuncParamStringRequired(ctx context.Context, slug string) error { + _, err := q.db.ExecContext(ctx, funcParamStringRequired, slug) + return err +} diff --git a/internal/endtoend/testdata/sqlc_arg/mysql/query.sql b/internal/endtoend/testdata/sqlc_arg/mysql/query.sql index a8a16f7de9..66e3eef8d5 100644 --- a/internal/endtoend/testdata/sqlc_arg/mysql/query.sql +++ b/internal/endtoend/testdata/sqlc_arg/mysql/query.sql @@ -1,7 +1,13 @@ -CREATE TABLE foo (name text not null); +CREATE TABLE foo (name text not null, description text); /* name: FuncParamIdent :many */ SELECT name FROM foo WHERE name = sqlc.arg(slug); /* name: FuncParamString :many */ SELECT name FROM foo WHERE name = sqlc.arg('slug'); + +/* name: FuncParamStringOptional :exec */ +UPDATE foo SET name = coalesce(sqlc.arg('slug?'), name); + +/* name: FuncParamStringRequired :exec */ +UPDATE foo SET description = sqlc.arg('slug!'); diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go index 82d2d6ffa5..f51ae42348 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go @@ -4,8 +4,11 @@ package querytest -import () +import ( + "database/sql" +) type Foo struct { - Name string + Name string + Description sql.NullString } diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go index e5041b2e3e..af8c4be493 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go @@ -7,6 +7,7 @@ package querytest import ( "context" + "database/sql" ) const funcParamIdent = `-- name: FuncParamIdent :many @@ -56,3 +57,21 @@ func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, e } return items, nil } + +const funcParamStringOptional = `-- name: FuncParamStringOptional :exec +UPDATE foo SET name = coalesce($1, name) +` + +func (q *Queries) FuncParamStringOptional(ctx context.Context, slug sql.NullString) error { + _, err := q.db.Exec(ctx, funcParamStringOptional, slug) + return err +} + +const funcParamStringRequired = `-- name: FuncParamStringRequired :exec +UPDATE foo SET description = $1 +` + +func (q *Queries) FuncParamStringRequired(ctx context.Context, slug string) error { + _, err := q.db.Exec(ctx, funcParamStringRequired, slug) + return err +} diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql index 9a8e98e223..8f7fd98aa9 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql @@ -1,7 +1,13 @@ -CREATE TABLE foo (name text not null); +CREATE TABLE foo (name text not null, description text); -- name: FuncParamIdent :many SELECT name FROM foo WHERE name = sqlc.arg(slug); -- name: FuncParamString :many SELECT name FROM foo WHERE name = sqlc.arg('slug'); + +-- name: FuncParamStringOptional :exec +UPDATE foo SET name = coalesce(sqlc.arg('slug?'), name); + +-- name: FuncParamStringRequired :exec +UPDATE foo SET description = sqlc.arg('slug!'); diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go index 82d2d6ffa5..f51ae42348 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go @@ -4,8 +4,11 @@ package querytest -import () +import ( + "database/sql" +) type Foo struct { - Name string + Name string + Description sql.NullString } diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go index 43d947b278..16280f0a61 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go @@ -7,6 +7,7 @@ package querytest import ( "context" + "database/sql" ) const funcParamIdent = `-- name: FuncParamIdent :many @@ -62,3 +63,21 @@ func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, e } return items, nil } + +const funcParamStringOptional = `-- name: FuncParamStringOptional :exec +UPDATE foo SET name = coalesce($1, name) +` + +func (q *Queries) FuncParamStringOptional(ctx context.Context, slug sql.NullString) error { + _, err := q.db.ExecContext(ctx, funcParamStringOptional, slug) + return err +} + +const funcParamStringRequired = `-- name: FuncParamStringRequired :exec +UPDATE foo SET description = $1 +` + +func (q *Queries) FuncParamStringRequired(ctx context.Context, slug string) error { + _, err := q.db.ExecContext(ctx, funcParamStringRequired, slug) + return err +} diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql index 9a8e98e223..8f7fd98aa9 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql @@ -1,7 +1,13 @@ -CREATE TABLE foo (name text not null); +CREATE TABLE foo (name text not null, description text); -- name: FuncParamIdent :many SELECT name FROM foo WHERE name = sqlc.arg(slug); -- name: FuncParamString :many SELECT name FROM foo WHERE name = sqlc.arg('slug'); + +-- name: FuncParamStringOptional :exec +UPDATE foo SET name = coalesce(sqlc.arg('slug?'), name); + +-- name: FuncParamStringRequired :exec +UPDATE foo SET description = sqlc.arg('slug!'); From b95ee5fbc3900f8e4e2534d35edcb7a6294fc97a Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Fri, 8 Apr 2022 23:38:47 -0700 Subject: [PATCH 07/11] add named.ParamSet to represent a set of parameters in a query --- internal/compiler/resolve.go | 30 +++---- internal/sql/named/param.go | 4 +- internal/sql/named/param_set.go | 78 ++++++++++++++++++ internal/sql/named/param_set_test.go | 58 ++++++++++++++ internal/sql/named/param_test.go | 10 +-- internal/sql/rewrite/parameters.go | 115 +++++++-------------------- internal/sql/validate/param_ref.go | 1 + 7 files changed, 185 insertions(+), 111 deletions(-) create mode 100644 internal/sql/named/param_set.go create mode 100644 internal/sql/named/param_set_test.go diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index d9225ba7a5..1b33ad96c1 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -19,7 +19,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params map[int]named.Param) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet) ([]Parameter, error) { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -27,14 +27,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, var defaultTable *ast.TableName var tables []*ast.TableName - // fetch param fetches the named parameter at index `n` or a default substitution - // and returns whether it was found or not. - fetchParam := func(n int, defaultP named.Param) (named.Param, bool) { - p, ok := params[n] - p = named.Combine(p, defaultP) - return p, ok - } - typeMap := map[string]map[string]map[string]*catalog.Column{} indexTable := func(table catalog.Table) error { tables = append(tables, table.Rel) @@ -90,7 +82,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, case *limitOffset: defaultP := named.NewInferredParam("offset", true) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -103,7 +95,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, case *limitCount: defaultP := named.NewInferredParam("limit", true) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -130,7 +122,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } defaultP := named.NewUnspecifiedParam("") - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -192,7 +184,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } defaultP := named.NewInferredParam(key, c.IsNotNull) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -251,7 +243,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if c, ok := typeMap[schema][table.Name][key]; ok { defaultP := named.NewInferredParam(key, c.IsNotNull) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ @@ -321,7 +313,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } defaultP := named.NewInferredParam(defaultName, false) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -355,7 +347,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } defaultP := named.NewInferredParam(paramName, true) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -416,7 +408,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if c, ok := tableMap[key]; ok { defaultP := named.NewInferredParam(key, c.IsNotNull) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, Column: &Column{ @@ -443,7 +435,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } col := toColumn(n.TypeName) defaultP := named.NewInferredParam(col.Name, col.NotNull) - p, _ := fetchParam(ref.ref.Number, defaultP) + p, _ := params.FetchMerge(ref.ref.Number, defaultP) col.Name = p.Name() col.NotNull = p.NotNull() @@ -523,7 +515,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, key = ref.name } defaultP := named.NewInferredParam(key, c.IsNotNull) - p, isNamed := fetchParam(ref.ref.Number, defaultP) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: number, Column: &Column{ diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go index fcebd65814..d660144a57 100644 --- a/internal/sql/named/param.go +++ b/internal/sql/named/param.go @@ -106,9 +106,9 @@ func (p Param) NotNull() bool { return null } -// Combine creates a new param from 2 partially specified params +// mergeParam creates a new param from 2 partially specified params // If the parameters have different names, the first is preferred -func Combine(a, b Param) Param { +func mergeParam(a, b Param) Param { name := a.name if name == "" { name = b.name diff --git a/internal/sql/named/param_set.go b/internal/sql/named/param_set.go new file mode 100644 index 0000000000..b30de738b3 --- /dev/null +++ b/internal/sql/named/param_set.go @@ -0,0 +1,78 @@ +package named + +// ParamSet represents a set of parameters for a single query +type ParamSet struct { + // does this engine support named parameters? + hasNamedSupport bool + // the set of currently tracked named parameters + namedParams map[string]Param + // the locations of each of the named parameters + namedLocs map[string][]int + // a map of positions currently used + positionToName map[int]string + // argn keeps track of the last checked positional parameter used + argn int +} + +func (p *ParamSet) nextArgNum() int { + for { + if _, ok := p.positionToName[p.argn]; !ok { + return p.argn + } + + p.argn++ + } +} + +// Add adds a parameter to this set and returns the numbered location used for it +func (p *ParamSet) Add(param Param) int { + name := param.name + existing, ok := p.namedParams[name] + + p.namedParams[name] = mergeParam(existing, param) + if ok && p.hasNamedSupport { + return p.namedLocs[name][0] + } + + argn := p.nextArgNum() + p.positionToName[argn] = name + p.namedLocs[name] = append(p.namedLocs[name], argn) + return argn +} + +// FetchMerge fetches an indexed parameter, and merges `mergeP` into it +// Returns: the merged parameter and whether it was a named parameter +func (p *ParamSet) FetchMerge(idx int, mergeP Param) (param Param, isNamed bool) { + name, exists := p.positionToName[idx] + if !exists || name == "" { + return mergeP, false + } + + param, ok := p.namedParams[name] + if !ok { + return mergeP, false + } + + return mergeParam(param, mergeP), true +} + +// NewParamSet creates a set of parameters with the given list of already used positions +func NewParamSet(positionsUsed map[int]bool, hasNamedSupport bool) *ParamSet { + positionToName := make(map[int]string, len(positionsUsed)) + for index, used := range positionsUsed { + if !used { + continue + } + + // assume the previously used params have no name + positionToName[index] = "" + } + + return &ParamSet{ + argn: 1, + namedParams: make(map[string]Param), + namedLocs: make(map[string][]int), + hasNamedSupport: hasNamedSupport, + positionToName: positionToName, + } +} diff --git a/internal/sql/named/param_set_test.go b/internal/sql/named/param_set_test.go new file mode 100644 index 0000000000..30e0cfc7da --- /dev/null +++ b/internal/sql/named/param_set_test.go @@ -0,0 +1,58 @@ +package named + +import "testing" + +func TestParamSet_Add(t *testing.T) { + t.Parallel() + + type test struct { + pset *ParamSet + param Param + expected int + } + + named := NewParamSet(nil, true) + populatedNamed := NewParamSet(map[int]bool{1: true, 2: true, 4: true, 5: true, 6: true}, true) + populatedUnnamed := NewParamSet(map[int]bool{1: true, 2: true, 4: true, 5: true, 6: true}, false) + unnamed := NewParamSet(nil, false) + p1 := NewUnspecifiedParam("hello") + p2 := NewUnspecifiedParam("world") + + tests := []test{ + // First parameter should be 1 + {named, p1, 1}, + // Duplicate first parameters should be 1 + {named, p1, 1}, + // A new parameter receives a new parameter number + {named, p2, 2}, + // An additional new parameter does _not_ receive a new + {named, p2, 2}, + + // First parameter should be 1 + {unnamed, p1, 1}, + // Duplicate first parameters should increment argn + {unnamed, p1, 2}, + // A new parameter receives a new parameter number + {unnamed, p2, 3}, + // An additional new parameter still does receive a new argn + {unnamed, p2, 4}, + + // First parameter of a pre-populated should be 3 + {populatedNamed, p1, 3}, + {populatedNamed, p1, 3}, + {populatedNamed, p2, 7}, + {populatedNamed, p2, 7}, + + {populatedUnnamed, p1, 3}, + {populatedUnnamed, p1, 7}, + {populatedUnnamed, p2, 8}, + {populatedUnnamed, p2, 9}, + } + + for _, spec := range tests { + actual := spec.pset.Add(spec.param) + if actual != spec.expected { + t.Errorf("ParamSet.Add(%s) expected %v; got %v", spec.param.name, spec.expected, actual) + } + } +} diff --git a/internal/sql/named/param_test.go b/internal/sql/named/param_test.go index 49748f6d70..fb5f38ecdd 100644 --- a/internal/sql/named/param_test.go +++ b/internal/sql/named/param_test.go @@ -2,7 +2,7 @@ package named import "testing" -func TestCombineNullability(t *testing.T) { +func TestMergeParamNullability(t *testing.T) { type test struct { a Param b Param @@ -39,21 +39,21 @@ func TestCombineNullability(t *testing.T) { for _, spec := range tests { a := spec.a b := spec.b - actual := Combine(a, b).NotNull() + actual := mergeParam(a, b).NotNull() expected := spec.notNull if actual != expected { t.Errorf("Combine(%s,%s) expected %v; got %v", a.nullability, b.nullability, expected, actual) } // We have already tried Combine(a, b) the same result should be true for Combine(b, a) - actual = Combine(b, a).NotNull() + actual = mergeParam(b, a).NotNull() if actual != expected { t.Errorf("Combine(%s,%s) expected %v; got %v", b.nullability, a.nullability, expected, actual) } } } -func TestCombineName(t *testing.T) { +func TestMergeParamName(t *testing.T) { type test struct { a Param b Param @@ -77,7 +77,7 @@ func TestCombineName(t *testing.T) { for _, spec := range tests { a := spec.a b := spec.b - actual := Combine(a, b).Name() + actual := mergeParam(a, b).Name() expected := spec.name if actual != expected { t.Errorf("Combine(%s,%s) expected %v; got %v", a.name, b.name, expected, actual) diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index 064fa08ffc..fb92410d0e 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,21 +41,6 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -type namedParameter struct { - param named.Param - locs []int -} - -// Add a new instance of this parameter of the same name -func (n namedParameter) AddInstance(loc int, p named.Param) namedParameter { - param := named.Combine(n.param, p) - locs := append(n.locs, loc) - return namedParameter{ - param: param, - locs: locs, - } -} - // paramFromName takes a user-defined parameter name, with an optional suffix of // ? (nullable), or ! (non-null) and builds the appropiate parameter func paramFromName(name string) named.Param { @@ -75,17 +60,16 @@ func paramFromName(name string) named.Param { return named.NewUnspecifiedParam(name) } -func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, map[int]named.Param, []source.Edit) { +func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, *named.ParamSet, []source.Edit) { foundFunc := astutils.Search(raw, named.IsParamFunc) foundSign := astutils.Search(raw, named.IsParamSign) + hasNamedParameterSupport := engine != config.EngineMySQL + allParams := named.NewParamSet(numbs, hasNamedParameterSupport) + if len(foundFunc.Items)+len(foundSign.Items) == 0 { - return raw, map[int]named.Param{}, nil + return raw, allParams, nil } - hasNamedParameterSupport := engine != config.EngineMySQL - - args := map[string]namedParameter{} - argn := 0 var edits []source.Edit node := astutils.Apply(raw, func(cr *astutils.Cursor) bool { node := cr.Node() @@ -93,26 +77,14 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamFunc(node): fun := node.(*ast.FuncCall) paramName, isConst := flatten(fun.Args) + param := paramFromName(paramName) + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: fun.Location, + }) - if namedP, ok := args[param.Name()]; ok && hasNamedParameterSupport { - cr.Replace(&ast.ParamRef{ - Number: namedP.locs[0], - Location: fun.Location, - }) - } else { - // Find the arg number that has not yet been used - argn++ - for numbs[argn] { - argn++ - } - - args[param.Name()] = args[param.Name()].AddInstance(argn, param) - cr.Replace(&ast.ParamRef{ - Number: argn, - Location: fun.Location, - }) - } // TODO: This code assumes that sqlc.arg(name) is on a single line var old, replace string if isConst { @@ -123,7 +95,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param.Name()].locs[0]) + replace = fmt.Sprintf("$%d", argn) } edits = append(edits, source.Edit{ Location: fun.Location - raw.StmtLocation, @@ -137,32 +109,22 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, cast := expr.Rexpr.(*ast.TypeCast) paramName, _ := flatten(cast.Arg) param := paramFromName(paramName) - if p, ok := args[param.Name()]; ok { - cast.Arg = &ast.ParamRef{ - Number: p.locs[0], - Location: expr.Location, - } - cr.Replace(cast) - } else { - argn++ - for numbs[argn] { - argn++ - } - - args[param.Name()] = args[param.Name()].AddInstance(argn, param) - cast.Arg = &ast.ParamRef{ - Number: argn, - Location: expr.Location, - } - cr.Replace(cast) + + argn := allParams.Add(param) + cast.Arg = &ast.ParamRef{ + Number: argn, + Location: expr.Location, } + cr.Replace(cast) + // TODO: This code assumes that @foo::bool is on a single line var replace string if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param.Name()].locs[0]) + replace = fmt.Sprintf("$%d", argn) } + edits = append(edits, source.Edit{ Location: expr.Location - raw.StmtLocation, Old: fmt.Sprintf("@%s", paramName), @@ -174,29 +136,19 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, expr := node.(*ast.A_Expr) paramName, _ := flatten(expr.Rexpr) param := paramFromName(paramName) - if p, ok := args[param.Name()]; ok { - cr.Replace(&ast.ParamRef{ - Number: p.locs[0], - Location: expr.Location, - }) - } else { - argn++ - for numbs[argn] { - argn++ - } - - args[param.Name()] = args[param.Name()].AddInstance(argn, param) - cr.Replace(&ast.ParamRef{ - Number: argn, - Location: expr.Location, - }) - } + + argn := allParams.Add(param) + cr.Replace(&ast.ParamRef{ + Number: argn, + Location: expr.Location, + }) + // TODO: This code assumes that @foo is on a single line var replace string if engine == config.EngineMySQL || !dollar { replace = "?" } else { - replace = fmt.Sprintf("$%d", args[param.Name()].locs[0]) + replace = fmt.Sprintf("$%d", argn) } edits = append(edits, source.Edit{ @@ -211,12 +163,5 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, } }, nil) - paramByLoc := map[int]named.Param{} - for _, namedParam := range args { - for _, loc := range namedParam.locs { - paramByLoc[loc] = namedParam.param - } - } - - return node.(*ast.RawStmt), paramByLoc, edits + return node.(*ast.RawStmt), allParams, edits } diff --git a/internal/sql/validate/param_ref.go b/internal/sql/validate/param_ref.go index fbec8f9066..170a158527 100644 --- a/internal/sql/validate/param_ref.go +++ b/internal/sql/validate/param_ref.go @@ -3,6 +3,7 @@ package validate import ( "errors" "fmt" + "github.com/kyleconroy/sqlc/internal/sql/ast" "github.com/kyleconroy/sqlc/internal/sql/astutils" "github.com/kyleconroy/sqlc/internal/sql/sqlerr" From a7fa41cd793eae67c8f8f82f3d94d047d66eb517 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Sat, 30 Apr 2022 10:10:33 -0700 Subject: [PATCH 08/11] revert sqlc.arg tests with `!?` --- .../testdata/sqlc_arg/mysql/go/models.go | 7 ++----- .../testdata/sqlc_arg/mysql/go/query.sql.go | 19 ------------------- .../testdata/sqlc_arg/mysql/query.sql | 8 +------- .../sqlc_arg/postgresql/pgx/go/models.go | 7 ++----- .../sqlc_arg/postgresql/pgx/go/query.sql.go | 19 ------------------- .../sqlc_arg/postgresql/pgx/query.sql | 8 +------- .../sqlc_arg/postgresql/stdlib/go/models.go | 7 ++----- .../postgresql/stdlib/go/query.sql.go | 19 ------------------- .../sqlc_arg/postgresql/stdlib/query.sql | 8 +------- 9 files changed, 9 insertions(+), 93 deletions(-) diff --git a/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go b/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go index f51ae42348..82d2d6ffa5 100644 --- a/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go +++ b/internal/endtoend/testdata/sqlc_arg/mysql/go/models.go @@ -4,11 +4,8 @@ package querytest -import ( - "database/sql" -) +import () type Foo struct { - Name string - Description sql.NullString + Name string } diff --git a/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go index a9899dcc4e..7b6cbe5f67 100644 --- a/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_arg/mysql/go/query.sql.go @@ -7,7 +7,6 @@ package querytest import ( "context" - "database/sql" ) const funcParamIdent = `-- name: FuncParamIdent :many @@ -63,21 +62,3 @@ func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, e } return items, nil } - -const funcParamStringOptional = `-- name: FuncParamStringOptional :exec -UPDATE foo SET name = coalesce(?, name) -` - -func (q *Queries) FuncParamStringOptional(ctx context.Context, slug sql.NullString) error { - _, err := q.db.ExecContext(ctx, funcParamStringOptional, slug) - return err -} - -const funcParamStringRequired = `-- name: FuncParamStringRequired :exec -UPDATE foo SET description = ? -` - -func (q *Queries) FuncParamStringRequired(ctx context.Context, slug string) error { - _, err := q.db.ExecContext(ctx, funcParamStringRequired, slug) - return err -} diff --git a/internal/endtoend/testdata/sqlc_arg/mysql/query.sql b/internal/endtoend/testdata/sqlc_arg/mysql/query.sql index 66e3eef8d5..a8a16f7de9 100644 --- a/internal/endtoend/testdata/sqlc_arg/mysql/query.sql +++ b/internal/endtoend/testdata/sqlc_arg/mysql/query.sql @@ -1,13 +1,7 @@ -CREATE TABLE foo (name text not null, description text); +CREATE TABLE foo (name text not null); /* name: FuncParamIdent :many */ SELECT name FROM foo WHERE name = sqlc.arg(slug); /* name: FuncParamString :many */ SELECT name FROM foo WHERE name = sqlc.arg('slug'); - -/* name: FuncParamStringOptional :exec */ -UPDATE foo SET name = coalesce(sqlc.arg('slug?'), name); - -/* name: FuncParamStringRequired :exec */ -UPDATE foo SET description = sqlc.arg('slug!'); diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go index f51ae42348..82d2d6ffa5 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/models.go @@ -4,11 +4,8 @@ package querytest -import ( - "database/sql" -) +import () type Foo struct { - Name string - Description sql.NullString + Name string } diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go index af8c4be493..e5041b2e3e 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/go/query.sql.go @@ -7,7 +7,6 @@ package querytest import ( "context" - "database/sql" ) const funcParamIdent = `-- name: FuncParamIdent :many @@ -57,21 +56,3 @@ func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, e } return items, nil } - -const funcParamStringOptional = `-- name: FuncParamStringOptional :exec -UPDATE foo SET name = coalesce($1, name) -` - -func (q *Queries) FuncParamStringOptional(ctx context.Context, slug sql.NullString) error { - _, err := q.db.Exec(ctx, funcParamStringOptional, slug) - return err -} - -const funcParamStringRequired = `-- name: FuncParamStringRequired :exec -UPDATE foo SET description = $1 -` - -func (q *Queries) FuncParamStringRequired(ctx context.Context, slug string) error { - _, err := q.db.Exec(ctx, funcParamStringRequired, slug) - return err -} diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql index 8f7fd98aa9..9a8e98e223 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/pgx/query.sql @@ -1,13 +1,7 @@ -CREATE TABLE foo (name text not null, description text); +CREATE TABLE foo (name text not null); -- name: FuncParamIdent :many SELECT name FROM foo WHERE name = sqlc.arg(slug); -- name: FuncParamString :many SELECT name FROM foo WHERE name = sqlc.arg('slug'); - --- name: FuncParamStringOptional :exec -UPDATE foo SET name = coalesce(sqlc.arg('slug?'), name); - --- name: FuncParamStringRequired :exec -UPDATE foo SET description = sqlc.arg('slug!'); diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go index f51ae42348..82d2d6ffa5 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/models.go @@ -4,11 +4,8 @@ package querytest -import ( - "database/sql" -) +import () type Foo struct { - Name string - Description sql.NullString + Name string } diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go index 16280f0a61..43d947b278 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/go/query.sql.go @@ -7,7 +7,6 @@ package querytest import ( "context" - "database/sql" ) const funcParamIdent = `-- name: FuncParamIdent :many @@ -63,21 +62,3 @@ func (q *Queries) FuncParamString(ctx context.Context, slug string) ([]string, e } return items, nil } - -const funcParamStringOptional = `-- name: FuncParamStringOptional :exec -UPDATE foo SET name = coalesce($1, name) -` - -func (q *Queries) FuncParamStringOptional(ctx context.Context, slug sql.NullString) error { - _, err := q.db.ExecContext(ctx, funcParamStringOptional, slug) - return err -} - -const funcParamStringRequired = `-- name: FuncParamStringRequired :exec -UPDATE foo SET description = $1 -` - -func (q *Queries) FuncParamStringRequired(ctx context.Context, slug string) error { - _, err := q.db.ExecContext(ctx, funcParamStringRequired, slug) - return err -} diff --git a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql index 8f7fd98aa9..9a8e98e223 100644 --- a/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql +++ b/internal/endtoend/testdata/sqlc_arg/postgresql/stdlib/query.sql @@ -1,13 +1,7 @@ -CREATE TABLE foo (name text not null, description text); +CREATE TABLE foo (name text not null); -- name: FuncParamIdent :many SELECT name FROM foo WHERE name = sqlc.arg(slug); -- name: FuncParamString :many SELECT name FROM foo WHERE name = sqlc.arg('slug'); - --- name: FuncParamStringOptional :exec -UPDATE foo SET name = coalesce(sqlc.arg('slug?'), name); - --- name: FuncParamStringRequired :exec -UPDATE foo SET description = sqlc.arg('slug!'); From e4c23a32aad1c537df8d78599ddde5cbfb10b868 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Sat, 30 Apr 2022 11:35:43 -0700 Subject: [PATCH 09/11] add sqlc.narg for params users can override to nullable --- internal/compiler/resolve.go | 2 +- internal/source/code.go | 30 ++-- internal/source/mutate_test.go | 210 +++++++++++++++++++++++++++ internal/sql/named/is.go | 6 +- internal/sql/named/param.go | 14 +- internal/sql/named/param_set_test.go | 4 +- internal/sql/named/param_test.go | 14 +- internal/sql/rewrite/parameters.go | 50 +++---- internal/sql/validate/func_call.go | 2 +- 9 files changed, 271 insertions(+), 61 deletions(-) create mode 100644 internal/source/mutate_test.go diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 1b33ad96c1..4074116e3b 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -121,7 +121,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, dataType = "string" } - defaultP := named.NewUnspecifiedParam("") + defaultP := named.NewParam("") p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) a = append(a, Parameter{ Number: ref.ref.Number, diff --git a/internal/source/code.go b/internal/source/code.go index b84324b55f..9a6ed077d3 100644 --- a/internal/source/code.go +++ b/internal/source/code.go @@ -54,25 +54,31 @@ func Mutate(raw string, a []Edit) (string, error) { if len(a) == 0 { return raw, nil } + sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location }) + s := raw - for _, edit := range a { + for idx, edit := range a { start := edit.Location - if start > len(s) { + if start > len(s) || start < 0 { return "", fmt.Errorf("edit start location is out of bounds") } - if len(edit.New) <= 0 { - return "", fmt.Errorf("empty edit contents") - } - if len(edit.Old) <= 0 { - return "", fmt.Errorf("empty edit contents") + + stop := edit.Location + len(edit.Old) + if stop > len(s) { + return "", fmt.Errorf("edit stop location is out of bounds") } - stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty - if stop < len(s) { - s = s[:start] + edit.New + s[stop+1:] - } else { - s = s[:start] + edit.New + + // If this is not the first edit, (applied backwards), check if + // this edit overlaps the previous one (and is therefore a developer error) + if idx != 0 { + prevEdit := a[idx-1] + if prevEdit.Location < edit.Location+len(edit.Old) { + return "", fmt.Errorf("2 edits overlap") + } } + + s = s[:start] + edit.New + s[stop:] } return s, nil } diff --git a/internal/source/mutate_test.go b/internal/source/mutate_test.go new file mode 100644 index 0000000000..dd76888796 --- /dev/null +++ b/internal/source/mutate_test.go @@ -0,0 +1,210 @@ +package source + +import ( + "fmt" + "testing" +) + +// newEdit is a testing helper for quickly generating Edits +func newEdit(loc int, old, new string) Edit { + return Edit{Location: loc, Old: old, New: new} +} + +// TestMutateSingle tests almost every possibility of a single edit +func TestMutateSingle(t *testing.T) { + type test struct { + input string + edit Edit + expected string + } + + tests := []test{ + // Simple edits that replace everything + {"", newEdit(0, "", ""), ""}, + {"a", newEdit(0, "a", "A"), "A"}, + {"abcde", newEdit(0, "abcde", "fghij"), "fghij"}, + {"", newEdit(0, "", "fghij"), "fghij"}, + {"abcde", newEdit(0, "abcde", ""), ""}, + + // Edits that start at the very beginning (But don't cover the whole range) + {"abcde", newEdit(0, "a", "A"), "Abcde"}, + {"abcde", newEdit(0, "ab", "AB"), "ABcde"}, + {"abcde", newEdit(0, "abc", "ABC"), "ABCde"}, + {"abcde", newEdit(0, "abcd", "ABCD"), "ABCDe"}, + + // The above repeated, but with different lengths + {"abcde", newEdit(0, "a", ""), "bcde"}, + {"abcde", newEdit(0, "ab", "A"), "Acde"}, + {"abcde", newEdit(0, "abc", "AB"), "ABde"}, + {"abcde", newEdit(0, "abcd", "AB"), "ABe"}, + + // Edits that touch the end (but don't cover the whole range) + {"abcde", newEdit(4, "e", "E"), "abcdE"}, + {"abcde", newEdit(3, "de", "DE"), "abcDE"}, + {"abcde", newEdit(2, "cde", "CDE"), "abCDE"}, + {"abcde", newEdit(1, "bcde", "BCDE"), "aBCDE"}, + + // The above repeated, but with different lengths + {"abcde", newEdit(4, "e", ""), "abcd"}, + {"abcde", newEdit(3, "de", "D"), "abcD"}, + {"abcde", newEdit(2, "cde", "CD"), "abCD"}, + {"abcde", newEdit(1, "bcde", "BC"), "aBC"}, + + // Raw insertions / deletions + {"abcde", newEdit(0, "", "_"), "_abcde"}, + {"abcde", newEdit(1, "", "_"), "a_bcde"}, + {"abcde", newEdit(2, "", "_"), "ab_cde"}, + {"abcde", newEdit(3, "", "_"), "abc_de"}, + {"abcde", newEdit(4, "", "_"), "abcd_e"}, + {"abcde", newEdit(5, "", "_"), "abcde_"}, + } + + origTests := tests + // Generate the reverse mutations, for every edit - the opposite edit that makes it "undo" + for _, spec := range origTests { + tests = append(tests, test{ + input: spec.expected, + edit: newEdit(spec.edit.Location, spec.edit.New, spec.edit.Old), + expected: spec.input, + }) + } + + for _, spec := range tests { + expected := spec.expected + + actual, err := Mutate(spec.input, []Edit{spec.edit}) + testName := fmt.Sprintf("Mutate(%s, Edit{%v, %v -> %v})", spec.input, spec.edit.Location, spec.edit.Old, spec.edit.New) + if err != nil { + t.Errorf("%s should not error (%v)", testName, err) + continue + } + + if actual != expected { + t.Errorf("%s expected %v; got %v", testName, expected, actual) + } + } +} + +// TestMutateMulti tests combinations of edits +func TestMutateMulti(t *testing.T) { + type test struct { + input string + edit1 Edit + edit2 Edit + expected string + } + + tests := []test{ + // Edits that are >1 character from each other + {"abcde", newEdit(0, "a", "A"), newEdit(2, "c", "C"), "AbCde"}, + {"abcde", newEdit(0, "a", "A"), newEdit(2, "c", "C"), "AbCde"}, + + // 2 edits bump right up next to each other + {"abcde", newEdit(0, "abc", ""), newEdit(3, "de", "DE"), "DE"}, + {"abcde", newEdit(0, "abc", "ABC"), newEdit(3, "de", ""), "ABC"}, + {"abcde", newEdit(0, "abc", "ABC"), newEdit(3, "de", "DE"), "ABCDE"}, + {"abcde", newEdit(1, "b", "BB"), newEdit(2, "c", "CC"), "aBBCCde"}, + + // 2 edits bump next to each other, but don't cover the whole string + {"abcdef", newEdit(1, "bc", "C"), newEdit(3, "de", "D"), "aCDf"}, + {"abcde", newEdit(1, "bc", "CCCC"), newEdit(3, "d", "DDD"), "aCCCCDDDe"}, + + // lengthening edits + {"abcde", newEdit(1, "b", "BBBB"), newEdit(2, "c", "CCCC"), "aBBBBCCCCde"}, + } + + origTests := tests + // Generate the edits in opposite order mutations, source edits should be independent of + // the order the edits are specified + for _, spec := range origTests { + tests = append(tests, test{ + input: spec.input, + edit1: spec.edit2, + edit2: spec.edit1, + expected: spec.expected, + }) + } + + for _, spec := range tests { + expected := spec.expected + + actual, err := Mutate(spec.input, []Edit{spec.edit1, spec.edit2}) + testName := fmt.Sprintf("Mutate(%s, Edits{(%v, %v -> %v), (%v, %v -> %v)})", spec.input, + spec.edit1.Location, spec.edit1.Old, spec.edit1.New, + spec.edit2.Location, spec.edit2.Old, spec.edit2.New) + + if err != nil { + t.Errorf("%s should not error (%v)", testName, err) + continue + } + + if actual != expected { + t.Errorf("%s expected %v; got %v", testName, expected, actual) + } + } +} + +// TestMutateErrorSingle test errors are generated for trivially incorrect single edits +func TestMutateErrorSingle(t *testing.T) { + type test struct { + input string + edit Edit + } + + tests := []test{ + // old text is longer than input text + {"", newEdit(0, "a", "A")}, + {"a", newEdit(0, "aa", "A")}, + {"hello", newEdit(0, "hello!", "A")}, + + // negative indexes + {"aaa", newEdit(-1, "aa", "A")}, + {"aaa", newEdit(-2, "aa", "A")}, + {"aaa", newEdit(-100, "aa", "A")}, + } + + for _, spec := range tests { + edit := spec.edit + + _, err := Mutate(spec.input, []Edit{edit}) + testName := fmt.Sprintf("Mutate(%s, Edit{%v, %v -> %v})", spec.input, edit.Location, edit.Old, edit.New) + if err == nil { + t.Errorf("%s should error (%v)", testName, err) + continue + } + } +} + +// TestMutateErrorMulti tests error that can only happen across multiple errors +func TestMutateErrorMulti(t *testing.T) { + type test struct { + input string + edit1 Edit + edit2 Edit + } + + tests := []test{ + // These edits overlap each other, and are therefore undefined + {"abcdef", newEdit(0, "a", ""), newEdit(0, "a", "A")}, + {"abcdef", newEdit(0, "ab", ""), newEdit(1, "ab", "AB")}, + {"abcdef", newEdit(0, "abc", ""), newEdit(2, "abc", "ABC")}, + + // the last edit is longer than the string itself + {"abcdef", newEdit(0, "abcdefghi", ""), newEdit(2, "abc", "ABC")}, + + // negative indexes + {"abcdef", newEdit(-1, "abc", ""), newEdit(3, "abc", "ABC")}, + {"abcdef", newEdit(0, "abc", ""), newEdit(-1, "abc", "ABC")}, + } + + for _, spec := range tests { + actual, err := Mutate(spec.input, []Edit{spec.edit1, spec.edit2}) + testName := fmt.Sprintf("Mutate(%s, Edits{(%v, %v -> %v), (%v, %v -> %v)})", spec.input, + spec.edit1.Location, spec.edit1.Old, spec.edit1.New, + spec.edit2.Location, spec.edit2.Old, spec.edit2.New) + + if err == nil { + t.Errorf("%s should error, but got (%v)", testName, actual) + } + } +} diff --git a/internal/sql/named/is.go b/internal/sql/named/is.go index 5421a85bb1..ba26c645d2 100644 --- a/internal/sql/named/is.go +++ b/internal/sql/named/is.go @@ -5,15 +5,19 @@ import ( "github.com/kyleconroy/sqlc/internal/sql/astutils" ) +// IsParamFunc fulfills the astutils.Search func IsParamFunc(node ast.Node) bool { call, ok := node.(*ast.FuncCall) if !ok { return false } + if call.Func == nil { return false } - return call.Func.Schema == "sqlc" && call.Func.Name == "arg" + + isValid := call.Func.Schema == "sqlc" && (call.Func.Name == "arg" || call.Func.Name == "narg") + return isValid } func IsParamSign(node ast.Node) bool { diff --git a/internal/sql/named/param.go b/internal/sql/named/param.go index d660144a57..ec29e6184d 100644 --- a/internal/sql/named/param.go +++ b/internal/sql/named/param.go @@ -46,8 +46,8 @@ type Param struct { nullability nullability } -// NewUnspecifiedParam builds a new params with unspecified nullability -func NewUnspecifiedParam(name string) Param { +// NewParam builds a new params with unspecified nullability +func NewParam(name string) Param { return Param{name: name, nullability: nullUnspecified} } @@ -60,13 +60,9 @@ func NewInferredParam(name string, notNull bool) Param { return Param{name: name, nullability: inferredNull} } -// NewUserDefinedParam creates a new param with the user specified -// by the end user -func NewUserDefinedParam(name string, notNull bool) Param { - if notNull { - return Param{name: name, nullability: notNullable} - } - +// NewUserNullableParam is a parameter that has been overridden +// by the user to be nullable. +func NewUserNullableParam(name string) Param { return Param{name: name, nullability: nullable} } diff --git a/internal/sql/named/param_set_test.go b/internal/sql/named/param_set_test.go index 30e0cfc7da..99b7ed0575 100644 --- a/internal/sql/named/param_set_test.go +++ b/internal/sql/named/param_set_test.go @@ -15,8 +15,8 @@ func TestParamSet_Add(t *testing.T) { populatedNamed := NewParamSet(map[int]bool{1: true, 2: true, 4: true, 5: true, 6: true}, true) populatedUnnamed := NewParamSet(map[int]bool{1: true, 2: true, 4: true, 5: true, 6: true}, false) unnamed := NewParamSet(nil, false) - p1 := NewUnspecifiedParam("hello") - p2 := NewUnspecifiedParam("world") + p1 := NewParam("hello") + p2 := NewParam("world") tests := []test{ // First parameter should be 1 diff --git a/internal/sql/named/param_test.go b/internal/sql/named/param_test.go index fb5f38ecdd..2643f8b308 100644 --- a/internal/sql/named/param_test.go +++ b/internal/sql/named/param_test.go @@ -11,11 +11,10 @@ func TestMergeParamNullability(t *testing.T) { } name := "name" - unspec := NewUnspecifiedParam(name) + unspec := NewParam(name) inferredNotNull := NewInferredParam(name, true) inferredNull := NewInferredParam(name, false) - userDefNotNull := NewUserDefinedParam(name, true) - userDefNull := NewUserDefinedParam(name, false) + userDefNull := NewUserNullableParam(name) const notNull = true const null = false @@ -24,16 +23,13 @@ func TestMergeParamNullability(t *testing.T) { // Unspecified nullability parameter works {unspec, inferredNotNull, notNull, "Unspec + inferred(not null) = not null"}, {unspec, inferredNull, null, "Unspec + inferred(not null) = null"}, - {unspec, userDefNotNull, notNull, "Unspec + userdef(not null) = not null"}, {unspec, userDefNull, null, "Unspec + userdef(null) = null"}, // Inferred nullability agreeing with user defined nullabilty - {inferredNotNull, userDefNotNull, notNull, "inferred(not null) + userdef(not null) = not null"}, {inferredNull, userDefNull, null, "inferred(null) + userdef(null) = null"}, // Inferred nullability disagreeing with user defined nullabilty {inferredNotNull, userDefNull, null, "inferred(not null) + userdef(null) = null"}, - {inferredNull, userDefNotNull, notNull, "inferred(null) + userdef(not null) = not null"}, } for _, spec := range tests { @@ -60,9 +56,9 @@ func TestMergeParamName(t *testing.T) { name string } - a := NewUnspecifiedParam("a") - b := NewUnspecifiedParam("b") - blank := NewUnspecifiedParam("") + a := NewParam("a") + b := NewParam("b") + blank := NewParam("") tests := []test{ // should prefer the first param's name if both specified diff --git a/internal/sql/rewrite/parameters.go b/internal/sql/rewrite/parameters.go index fb92410d0e..250d967e76 100644 --- a/internal/sql/rewrite/parameters.go +++ b/internal/sql/rewrite/parameters.go @@ -41,23 +41,28 @@ func isNamedParamSignCast(node ast.Node) bool { return astutils.Join(expr.Name, ".") == "@" && cast } -// paramFromName takes a user-defined parameter name, with an optional suffix of -// ? (nullable), or ! (non-null) and builds the appropiate parameter -func paramFromName(name string) named.Param { - if len(name) == 0 { - return named.NewUnspecifiedParam(name) +// paramFromFuncCall creates a param from sqlc.n?arg() calls return the +// parameter and whether the parameter name was specified a best guess as its +// "source" string representation (used for replacing this function call in the +// original SQL query) +func paramFromFuncCall(call *ast.FuncCall) (named.Param, string) { + paramName, isConst := flatten(call.Args) + + // origName keeps track of how the parameter was specified in the source SQL + origName := paramName + if isConst { + origName = fmt.Sprintf("'%s'", paramName) } - last := name[len(name)-1] - if last == '!' { - return named.NewUserDefinedParam(name[:len(name)-1], true) + param := named.NewParam(paramName) + if call.Func.Name == "narg" { + param = named.NewUserNullableParam(paramName) } - if last == '?' { - return named.NewUserDefinedParam(name[:len(name)-1], false) - } - - return named.NewUnspecifiedParam(name) + // TODO: This code assumes that sqlc.arg(name) / sqlc.narg(name) is on a single line + // with no extraneous spaces (or any non-significant tokens for that matter) + origText := fmt.Sprintf("%s.%s(%s)", call.Func.Schema, call.Func.Name, origName) + return param, origText } func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, dollar bool) (*ast.RawStmt, *named.ParamSet, []source.Edit) { @@ -76,30 +81,23 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, switch { case named.IsParamFunc(node): fun := node.(*ast.FuncCall) - paramName, isConst := flatten(fun.Args) - - param := paramFromName(paramName) + param, origText := paramFromFuncCall(fun) argn := allParams.Add(param) cr.Replace(&ast.ParamRef{ Number: argn, Location: fun.Location, }) - // TODO: This code assumes that sqlc.arg(name) is on a single line - var old, replace string - if isConst { - old = fmt.Sprintf("sqlc.arg('%s')", paramName) - } else { - old = fmt.Sprintf("sqlc.arg(%s)", paramName) - } + var replace string if engine == config.EngineMySQL || !dollar { replace = "?" } else { replace = fmt.Sprintf("$%d", argn) } + edits = append(edits, source.Edit{ Location: fun.Location - raw.StmtLocation, - Old: old, + Old: origText, New: replace, }) return false @@ -108,7 +106,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, expr := node.(*ast.A_Expr) cast := expr.Rexpr.(*ast.TypeCast) paramName, _ := flatten(cast.Arg) - param := paramFromName(paramName) + param := named.NewParam(paramName) argn := allParams.Add(param) cast.Arg = &ast.ParamRef{ @@ -135,7 +133,7 @@ func NamedParameters(engine config.Engine, raw *ast.RawStmt, numbs map[int]bool, case named.IsParamSign(node): expr := node.(*ast.A_Expr) paramName, _ := flatten(expr.Rexpr) - param := paramFromName(paramName) + param := named.NewParam(paramName) argn := allParams.Add(param) cr.Replace(&ast.ParamRef{ diff --git a/internal/sql/validate/func_call.go b/internal/sql/validate/func_call.go index 85c3df0d7e..5fbac048d2 100644 --- a/internal/sql/validate/func_call.go +++ b/internal/sql/validate/func_call.go @@ -34,7 +34,7 @@ func (v *funcCallVisitor) Visit(node ast.Node) astutils.Visitor { // Custom validation for sqlc.arg // TODO: Replace this once type-checking is implemented if fn.Schema == "sqlc" { - if fn.Name != "arg" { + if !(fn.Name == "arg" || fn.Name == "narg") { v.err = sqlerr.FunctionNotFound("sqlc." + fn.Name) return nil } From a19061a81bdfac660f405f178b9ed62aa498fda9 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Sat, 30 Apr 2022 14:46:19 -0700 Subject: [PATCH 10/11] add end-to-end test for sqlc.narg --- .../endtoend/testdata/sqlc_narg/mysql/query.sql | 13 +++++++++++++ .../endtoend/testdata/sqlc_narg/mysql/sqlc.json | 12 ++++++++++++ .../testdata/sqlc_narg/postgresql/pgx/query.sql | 13 +++++++++++++ .../testdata/sqlc_narg/postgresql/pgx/sqlc.json | 13 +++++++++++++ .../testdata/sqlc_narg/postgresql/stdlib/query.sql | 13 +++++++++++++ .../testdata/sqlc_narg/postgresql/stdlib/sqlc.json | 12 ++++++++++++ 6 files changed, 76 insertions(+) create mode 100644 internal/endtoend/testdata/sqlc_narg/mysql/query.sql create mode 100644 internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/query.sql b/internal/endtoend/testdata/sqlc_narg/mysql/query.sql new file mode 100644 index 0000000000..634830cbdf --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (bar text not null, maybe_bar text); + +-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg(bar); + +-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg(maybe_bar); + +-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg('bar'); + +-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg('maybe_bar'); diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json b/internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json new file mode 100644 index 0000000000..0657f4db83 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "mysql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql new file mode 100644 index 0000000000..634830cbdf --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (bar text not null, maybe_bar text); + +-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg(bar); + +-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg(maybe_bar); + +-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg('bar'); + +-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg('maybe_bar'); diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..634830cbdf --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/query.sql @@ -0,0 +1,13 @@ +CREATE TABLE foo (bar text not null, maybe_bar text); + +-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg(bar); + +-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg(maybe_bar); + +-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = sqlc.narg('bar'); + +-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = sqlc.narg('maybe_bar'); diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..de427d069f --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "engine": "postgresql", + "path": "go", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} From 8a9dfab07227263af03cf4ec675a95f330dc7004 Mon Sep 17 00:00:00 2001 From: Steven Kabbes Date: Sat, 30 Apr 2022 14:46:43 -0700 Subject: [PATCH 11/11] commit generated output for sqlc.narg --- .../testdata/sqlc_narg/mysql/go/db.go | 31 +++++ .../testdata/sqlc_narg/mysql/go/models.go | 14 +++ .../testdata/sqlc_narg/mysql/go/query.sql.go | 119 ++++++++++++++++++ .../sqlc_narg/postgresql/pgx/go/db.go | 32 +++++ .../sqlc_narg/postgresql/pgx/go/models.go | 14 +++ .../sqlc_narg/postgresql/pgx/go/query.sql.go | 107 ++++++++++++++++ .../sqlc_narg/postgresql/stdlib/go/db.go | 31 +++++ .../sqlc_narg/postgresql/stdlib/go/models.go | 14 +++ .../postgresql/stdlib/go/query.sql.go | 119 ++++++++++++++++++ 9 files changed, 481 insertions(+) create mode 100644 internal/endtoend/testdata/sqlc_narg/mysql/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_narg/mysql/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go create mode 100644 internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/go/db.go b/internal/endtoend/testdata/sqlc_narg/mysql/go/db.go new file mode 100644 index 0000000000..36ef5f4f45 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/go/models.go b/internal/endtoend/testdata/sqlc_narg/mysql/go/models.go new file mode 100644 index 0000000000..faee232b20 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + Bar string + MaybeBar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go b/internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go new file mode 100644 index 0000000000..db493107fc --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/mysql/go/query.sql.go @@ -0,0 +1,119 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const identOnNonNullable = `-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = ? +` + +func (q *Queries) IdentOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, identOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const identOnNullable = `-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = ? +` + +func (q *Queries) IdentOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, identOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNonNullable = `-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = ? +` + +func (q *Queries) StringOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, stringOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNullable = `-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = ? +` + +func (q *Queries) StringOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, stringOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..b0157bd009 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..faee232b20 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + Bar string + MaybeBar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..80509257f8 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/pgx/go/query.sql.go @@ -0,0 +1,107 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const identOnNonNullable = `-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) IdentOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.Query(ctx, identOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const identOnNullable = `-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) IdentOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.Query(ctx, identOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNonNullable = `-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) StringOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.Query(ctx, stringOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNullable = `-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) StringOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.Query(ctx, stringOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..36ef5f4f45 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..faee232b20 --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/models.go @@ -0,0 +1,14 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + Bar string + MaybeBar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..2939df932e --- /dev/null +++ b/internal/endtoend/testdata/sqlc_narg/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,119 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.13.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const identOnNonNullable = `-- name: IdentOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) IdentOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, identOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const identOnNullable = `-- name: IdentOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) IdentOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, identOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNonNullable = `-- name: StringOnNonNullable :many +SELECT bar FROM foo WHERE bar = $1 +` + +func (q *Queries) StringOnNonNullable(ctx context.Context, bar sql.NullString) ([]string, error) { + rows, err := q.db.QueryContext(ctx, stringOnNonNullable, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var bar string + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const stringOnNullable = `-- name: StringOnNullable :many +SELECT maybe_bar FROM foo WHERE maybe_bar = $1 +` + +func (q *Queries) StringOnNullable(ctx context.Context, maybeBar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, stringOnNullable, maybeBar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var maybe_bar sql.NullString + if err := rows.Scan(&maybe_bar); err != nil { + return nil, err + } + items = append(items, maybe_bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +}