From 431792038e0bc6cd03d510e731406b9478438b9e Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Wed, 15 Jan 2020 15:33:07 -0800 Subject: [PATCH 01/11] examples: Tag the authors database test --- examples/authors/db_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/authors/db_test.go b/examples/authors/db_test.go index 6a4f8d1529..3efcbd7185 100644 --- a/examples/authors/db_test.go +++ b/examples/authors/db_test.go @@ -1,3 +1,5 @@ +// +build examples + package authors import ( From acc935a498b3c819944ea9257b3d3826903043f4 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Thu, 16 Jan 2020 12:58:39 -0800 Subject: [PATCH 02/11] abandon ship for now --- internal/dinosql/astutil.go | 176 ++++++++++++++++++++++++++ internal/dinosql/astutil_test.go | 39 ++++++ internal/dinosql/checks.go | 20 +++ internal/dinosql/checks_test.go | 4 + internal/dinosql/parser.go | 26 ++++ internal/dinosql/query_test.go | 18 ++- internal/dinosql/rewrite.go | 34 +++++ internal/mysql/example/db.go | 29 +++++ internal/mysql/example/models.go | 37 ++++++ internal/mysql/example/queries.sql.go | 100 +++++++++++++++ internal/named/named.go | 133 +++++++++++++++++++ internal/named/named_test.go | 121 ++++++++++++++++++ internal/pg/catalog.go | 1 + internal/pg/sqlc.go | 24 ++++ 14 files changed, 761 insertions(+), 1 deletion(-) create mode 100644 internal/dinosql/astutil.go create mode 100644 internal/dinosql/astutil_test.go create mode 100644 internal/dinosql/rewrite.go create mode 100644 internal/mysql/example/db.go create mode 100644 internal/mysql/example/models.go create mode 100644 internal/mysql/example/queries.sql.go create mode 100644 internal/named/named.go create mode 100644 internal/named/named_test.go create mode 100644 internal/pg/sqlc.go diff --git a/internal/dinosql/astutil.go b/internal/dinosql/astutil.go new file mode 100644 index 0000000000..95b980086a --- /dev/null +++ b/internal/dinosql/astutil.go @@ -0,0 +1,176 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dinosql + +import ( + "fmt" + "reflect" + + nodes "github.com/lfittl/pg_query_go/nodes" +) + +// An ApplyFunc is invoked by Apply for each node n, even if n is nil, +// before and/or after the node's children, using a Cursor describing +// the current node and providing operations on it. +// +// The return value of ApplyFunc controls the syntax tree traversal. +// See Apply for details. +type ApplyFunc func(*Cursor) bool + +// Apply traverses a syntax tree recursively, starting with root, +// and calling pre and post for each node as described below. +// Apply returns the syntax tree, possibly modified. +// +// If pre is not nil, it is called for each node before the node's +// children are traversed (pre-order). If pre returns false, no +// children are traversed, and post is not called for that node. +// +// If post is not nil, and a prior call of pre didn't return false, +// post is called for each node after its children are traversed +// (post-order). If post returns false, traversal is terminated and +// Apply returns immediately. +// +// Only fields that refer to AST nodes are considered children; +// i.e., token.Pos, Scopes, Objects, and fields of basic types +// (strings, etc.) are ignored. +// +// Children are traversed in the order in which they appear in the +// respective node's struct definition. A package's files are +// traversed in the filenames' alphabetical order. +// +func Apply(root nodes.Node, pre, post ApplyFunc) (result nodes.Node) { + parent := &struct{ nodes.Node }{root} + a := &application{pre: pre, post: post} + return a.apply(parent, "Node", nil, root) +} + +var abort = new(int) // singleton, to signal termination of Apply + +// A Cursor describes a node encountered during Apply. +// Information about the node and its parent is available +// from the Node, Parent, Name, and Index methods. +// +// If p is a variable of type and value of the current parent node +// c.Parent(), and f is the field identifier with name c.Name(), +// the following invariants hold: +// +// p.f == c.Node() if c.Index() < 0 +// p.f[c.Index()] == c.Node() if c.Index() >= 0 +// +// The methods Replace, Delete, InsertBefore, and InsertAfter +// can be used to change the AST without disrupting Apply. +type Cursor struct { + parent nodes.Node + name string + iter *iterator // valid if non-nil + node nodes.Node +} + +// Node returns the current Node. +func (c *Cursor) Node() nodes.Node { return c.node } + +// Parent returns the parent of the current Node. +func (c *Cursor) Parent() nodes.Node { return c.parent } + +// Name returns the name of the parent Node field that contains the current Node. +// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns +// the filename for the current Node. +func (c *Cursor) Name() string { return c.name } + +// Index reports the index >= 0 of the current Node in the slice of Nodes that +// contains it, or a value < 0 if the current Node is not part of a slice. +// The index of the current node changes if InsertBefore is called while +// processing the current node. +func (c *Cursor) Index() int { + if c.iter != nil { + return c.iter.index + } + return -1 +} + +// field returns the current node's parent field value. +func (c *Cursor) field() reflect.Value { + return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) +} + +// Replace replaces the current Node with n. +// The replacement node is not walked by Apply. +func (c *Cursor) Replace(n nodes.Node) { + c.node = n +} + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor + iter iterator +} + +func (a *application) apply(parent nodes.Node, name string, iter *iterator, n nodes.Node) nodes.Node { + // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead + cursor := Cursor{ + parent: parent, + name: name, + iter: iter, + node: n, + } + + if a.pre != nil && !a.pre(&cursor) { + return cursor.node + } + + // walk children + // (the order of the cases matches the order of the corresponding node types in go/ast) + switch n := n.(type) { + case nil: + // nothing to do + + case nodes.RawStmt: + n.Stmt = a.apply(n, "Stmt", nil, n.Stmt) + cursor.node = n + + case nodes.SelectStmt: + n.TargetList = a.apply(n, "TargetList", nil, n.TargetList) + cursor.node = n + + default: + panic(fmt.Sprintf("Apply: unexpected node type %T", n)) + } + + if a.post != nil && !a.post(&cursor) { + panic(abort) + } + + return cursor.node +} + +// An iterator controls iteration over a slice of nodes. +type iterator struct { + index, step int +} + +func (a *application) applyList(parent nodes.Node, name string) { + // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead + saved := a.iter + a.iter.index = 0 + for { + // must reload parent.name each time, since cursor modifications might change it + v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) + if a.iter.index >= v.Len() { + break + } + + // element x may be nil in a bad AST - be cautious + var x nodes.Node + if e := v.Index(a.iter.index); e.IsValid() { + x = e.Interface().(nodes.Node) + } + + a.iter.step = 1 + a.apply(parent, name, &a.iter, x) + a.iter.index += a.iter.step + } + a.iter = saved +} diff --git a/internal/dinosql/astutil_test.go b/internal/dinosql/astutil_test.go new file mode 100644 index 0000000000..5e04807b2e --- /dev/null +++ b/internal/dinosql/astutil_test.go @@ -0,0 +1,39 @@ +package dinosql + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + pg "github.com/lfittl/pg_query_go" + nodes "github.com/lfittl/pg_query_go/nodes" +) + +func TestApply(t *testing.T) { + input, err := pg.Parse("SELECT sqlc.arg(name)") + if err != nil { + t.Fatal(err) + } + output, err := pg.Parse("SELECT $1") + if err != nil { + t.Fatal(err) + } + + // spew.Dump(input.Statements[0]) + + expect := output.Statements[0] + actual := Apply(input.Statements[0], func(cr *Cursor) bool { + fun, ok := cr.Node().(nodes.FuncCall) + if !ok { + return true + } + if join(fun.Funcname, ".") == "sqlc.arg" { + cr.Replace(nodes.ParamRef{Number: 1}) + return false + } + return true + }, nil) + + if diff := cmp.Diff(expect, actual); diff != "" { + t.Errorf("rewrite mismatch:\n%s", diff) + } +} diff --git a/internal/dinosql/checks.go b/internal/dinosql/checks.go index 836f25ad94..cd959e46bb 100644 --- a/internal/dinosql/checks.go +++ b/internal/dinosql/checks.go @@ -120,3 +120,23 @@ func validateInsertStmt(stmt nodes.InsertStmt) error { } return nil } + +// A query can either use named parameters (sqlc.arg(param)) or positional +// parameters ($1), but not both +func validateParamStyle(n nodes.Node) error { + positional := search(n, func(node nodes.Node) bool { + _, ok := node.(nodes.ParamRef) + return ok + }) + named := search(n, func(node nodes.Node) bool { + fun, ok := node.(nodes.FuncCall) + return ok && join(fun.Funcname, ".") == "sqlc.arg" + }) + if len(named.Items) > 0 && len(positional.Items) > 0 { + return pg.Error{ + Code: "", // TODO: Pick a new error code + Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg)", + } + } + return nil +} diff --git a/internal/dinosql/checks_test.go b/internal/dinosql/checks_test.go index 96108c15cd..1f399085d4 100644 --- a/internal/dinosql/checks_test.go +++ b/internal/dinosql/checks_test.go @@ -66,6 +66,10 @@ func TestParserErrors(t *testing.T) { Location: 7, }, }, + { + "SELECT foo FROM bar WHERE baz = $1 AND bat = sqlc.arg(named);", + pg.Error{Code: "", Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg)"}, + }, } { test := tc t.Run(test.query, func(t *testing.T) { diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 4cd3823cd2..aade2d2fac 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -231,6 +231,11 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) { continue } source := string(blob) + // source, _, err := named.CompileNamedQuery(blob, named.DOLLAR) + // if err != nil { + // merr.Add(filename, "", 0, err) + // continue + // } tree, err := pg.Parse(source) if err != nil { merr.Add(filename, source, 0, err) @@ -420,6 +425,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err := validateParamRef(stmt); err != nil { return nil, err } + if err := validateParamStyle(stmt); err != nil { + return nil, err + } raw, ok := stmt.(nodes.RawStmt) if !ok { return nil, errors.New("node is not a statement") @@ -450,6 +458,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err := validateCmd(raw.Stmt, name, cmd); err != nil { return nil, err } + + // Query manipulation + raw, err = rewriteNamedParameters(raw) + if err != nil { + return nil, err + } + rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) params, err := resolveCatalogRefs(c, rvs, refs) @@ -962,6 +977,9 @@ type paramRef struct { parent nodes.Node rv *nodes.RangeVar ref nodes.ParamRef + + // HACK + name string } type paramSearch struct { @@ -997,6 +1015,11 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { switch n := node.(type) { case nodes.A_Expr: + if join(n.Name, "-") == "@" && n.Lexpr == nil { + param := nodes.ParamRef{Number: 1} + p.refs[1] = paramRef{parent: p.parent, rv: p.rangeVar, name: "slug", ref: param} + return nil + } p.parent = node case nodes.FuncCall: @@ -1235,6 +1258,9 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( for _, table := range search { if c, ok := typeMap[table.Schema][table.Rel][key]; ok { found += 1 + if ref.name != "" { + key = ref.name + } a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index c80b1d85e6..249c71c013 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -884,6 +884,22 @@ func TestQueries(t *testing.T) { }, }, }, + { + "named_parameter", + ` + CREATE TABLE foo (name text not null); + SELECT name FROM foo WHERE name = sqlc.arg(slug); + `, + Query{ + SQL: "SELECT name FROM foo WHERE name = $1;", + Columns: []core.Column{ + {Table: public("foo"), Name: "name", DataType: "text", NotNull: true}, + }, + Params: []Parameter{ + {1, core.Column{Table: public("foo"), Name: "slug", DataType: "text", NotNull: true}}, + }, + }, + }, } { test := tc t.Run(test.name, func(t *testing.T) { @@ -928,7 +944,7 @@ func TestComparisonOperators(t *testing.T) { } func TestUnknownFunctions(t *testing.T) { - stmt := ` + stmt := ` CREATE TABLE foo (id text not null); -- name: ListFoos :one SELECT id FROM foo WHERE id = frobnicate($1); diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go new file mode 100644 index 0000000000..8dcf04a84f --- /dev/null +++ b/internal/dinosql/rewrite.go @@ -0,0 +1,34 @@ +package dinosql + +import ( + nodes "github.com/lfittl/pg_query_go/nodes" +) + +type stringWalker struct { + String string +} + +func (s *stringWalker) Visit(node nodes.Node) Visitor { + if n, ok := node.(nodes.String); ok { + s.String += n.Str + } + return s +} + +func flatten(root nodes.Node) string { + sw := &stringWalker{} + Walk(sw, root) + return sw.String +} + +func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, error) { + named := search(raw, func(node nodes.Node) bool { + fun, ok := node.(nodes.FuncCall) + return ok && join(fun.Funcname, ".") == "sqlc.arg" + }) + for _, np := range named.Items { + fun := np.(nodes.FuncCall) + flatten(fun.Args) + } + return raw, nil +} diff --git a/internal/mysql/example/db.go b/internal/mysql/example/db.go new file mode 100644 index 0000000000..3df4e8545c --- /dev/null +++ b/internal/mysql/example/db.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. + +package teachersDB + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/mysql/example/models.go b/internal/mysql/example/models.go new file mode 100644 index 0000000000..7655ff580f --- /dev/null +++ b/internal/mysql/example/models.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. + +package teachersDB + +import ( + "database/sql" +) + +type DepartmentType string + +const ( + English DepartmentType = "English" + Math DepartmentType = "Math" +) + +func (e *DepartmentType) Scan(src interface{}) error { + *e = DepartmentType(src.([]byte)) + return nil +} + +type Teacher struct { + ID int `json:"id"` + FirstName sql.NullString `json:"first_name"` + LastName sql.NullString `json:"last_name"` + SchoolID int `json:"school_id"` + ClassID int `json:"class_id"` + SchoolLat sql.NullFloat64 `json:"school_lat"` + SchoolLng sql.NullFloat64 `json:"school_lng"` + Department DepartmentType `json:"department"` +} + +type Student struct { + ID int `json:"id"` + ClassID int `json:"class_id"` + FirstName sql.NullString `json:"first_name"` + LastName sql.NullString `json:"last_name"` +} diff --git a/internal/mysql/example/queries.sql.go b/internal/mysql/example/queries.sql.go new file mode 100644 index 0000000000..1d1a3570a6 --- /dev/null +++ b/internal/mysql/example/queries.sql.go @@ -0,0 +1,100 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: queries.sql + +package teachersDB + +import ( + "context" + "database/sql" +) + +const getSomeTeachers = `-- name: GetSomeTeachers :one +select school_id, id from teachers where school_lng > ? and school_lat < ? +` + +type GetSomeTeachersParams struct { + SchoolLng sql.NullFloat64 `json:"school_lng"` + SchoolLat sql.NullFloat64 `json:"school_lat"` +} + +type GetSomeTeachersRow struct { + SchoolID int `json:"school_id"` + ID int `json:"id"` +} + +func (q *Queries) GetSomeTeachers(ctx context.Context, arg GetSomeTeachersParams) (GetSomeTeachersRow, error) { + row := q.db.QueryRowContext(ctx, getSomeTeachers, arg.SchoolLng, arg.SchoolLat) + var i GetSomeTeachersRow + err := row.Scan(&i.SchoolID, &i.ID) + return i, err +} + +const getStudentsTeacher = `-- name: GetStudentsTeacher :one +select students.first_name, students.last_name, teachers.first_name as teacherFirstName, teachers.id as teacher_id from students left join teachers on teachers.class_id = students.class_id where students.id = ? +` + +type GetStudentsTeacherRow struct { + FirstName sql.NullString `json:"first_name"` + LastName sql.NullString `json:"last_name"` + TeacherFirstName sql.NullString `json:"teacherFirstName"` + TeacherID sql.NullInt64 `json:"teacher_id"` +} + +func (q *Queries) GetStudentsTeacher(ctx context.Context, studentID int) (GetStudentsTeacherRow, error) { + row := q.db.QueryRowContext(ctx, getStudentsTeacher, studentID) + var i GetStudentsTeacherRow + err := row.Scan( + &i.FirstName, + &i.LastName, + &i.TeacherFirstName, + &i.TeacherID, + ) + return i, err +} + +const getTeachersByID = `-- name: GetTeachersByID :one +select id, first_name, last_name, school_id, class_id, school_lat, school_lng, department from teachers where id = ? +` + +type GetTeachersByIDRow struct { + ID int `json:"id"` + FirstName sql.NullString `json:"first_name"` + LastName sql.NullString `json:"last_name"` + SchoolID int `json:"school_id"` + ClassID int `json:"class_id"` + SchoolLat sql.NullFloat64 `json:"school_lat"` + SchoolLng sql.NullFloat64 `json:"school_lng"` + Department DepartmentType `json:"department"` +} + +func (q *Queries) GetTeachersByID(ctx context.Context, id int) (GetTeachersByIDRow, error) { + row := q.db.QueryRowContext(ctx, getTeachersByID, id) + var i GetTeachersByIDRow + err := row.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.SchoolID, + &i.ClassID, + &i.SchoolLat, + &i.SchoolLng, + &i.Department, + ) + return i, err +} + +const teachersByID = `-- name: TeachersByID :one +select id, school_lat from teachers where id = ? limit 10 +` + +type TeachersByIDRow struct { + ID int `json:"id"` + SchoolLat sql.NullFloat64 `json:"school_lat"` +} + +func (q *Queries) TeachersByID(ctx context.Context, id int) (TeachersByIDRow, error) { + row := q.db.QueryRowContext(ctx, teachersByID, id) + var i TeachersByIDRow + err := row.Scan(&i.ID, &i.SchoolLat) + return i, err +} diff --git a/internal/named/named.go b/internal/named/named.go new file mode 100644 index 0000000000..2cb7253220 --- /dev/null +++ b/internal/named/named.go @@ -0,0 +1,133 @@ +package named + +// Copyright (c) 2013, Jason Moiron +// +// Permission is hereby granted, free of charge, to any person +// obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, +// copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +// OTHER DEALINGS IN THE SOFTWARE. + +import ( + "errors" + "strconv" + "unicode" +) + +// Bindvar types supported by Rebind, BindMap and BindStruct. +const ( + UNKNOWN = iota + QUESTION + DOLLAR + NAMED + AT +) + +// -- Compilation of Named Queries + +// Allow digits and letters in bind params; additionally runes are +// checked against underscores, meaning that bind params can have be +// alphanumeric with underscores. Mind the difference between unicode +// digits and numbers, where '5' is a digit but '五' is not. +var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} + +// FIXME: this function isn't safe for unicode named params, as a failing test +// can testify. This is not a regression but a failure of the original code +// as well. It should be modified to range over runes in a string rather than +// bytes, even though this is less convenient and slower. Hopefully the +// addition of the prepared NamedStmt (which will only do this once) will make +// up for the slightly slower ad-hoc NamedExec/NamedQuery. + +// compile a NamedQuery into an unbound query (using the '?' bindvar) and +// a list of names. +func CompileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { + names = make([]string, 0, 10) + rebound := make([]byte, 0, len(qs)) + + inName := false + last := len(qs) - 1 + currentVar := 1 + name := make([]byte, 0, 10) + + for i, b := range qs { + // a ':' while we're in a name is an error + if b == ':' { + // if this is the second ':' in a '::' escape sequence, append a ':' + if inName && i > 0 && qs[i-1] == ':' { + rebound = append(rebound, ':') + inName = false + continue + } else if inName { + err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) + return query, names, err + } + inName = true + name = []byte{} + } else if inName && i > 0 && b == '=' && len(name) == 0 { + rebound = append(rebound, ':', '=') + inName = false + continue + // if we're in a name, and this is an allowed character, continue + } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { + // append the byte to the name if we are in a name and not on the last byte + name = append(name, b) + // if we're in a name and it's not an allowed character, the name is done + } else if inName { + inName = false + // if this is the final byte of the string and it is part of the name, then + // make sure to add it to the name + if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { + name = append(name, b) + } + // add the string representation to the names list + names = append(names, string(name)) + // add a proper bindvar for the bindType + switch bindType { + // oracle only supports named type bind vars even for positional + case NAMED: + rebound = append(rebound, ':') + rebound = append(rebound, name...) + case QUESTION, UNKNOWN: + rebound = append(rebound, '?') + case DOLLAR: + rebound = append(rebound, '$') + for _, b := range strconv.Itoa(currentVar) { + rebound = append(rebound, byte(b)) + } + currentVar++ + case AT: + rebound = append(rebound, '@', 'p') + for _, b := range strconv.Itoa(currentVar) { + rebound = append(rebound, byte(b)) + } + currentVar++ + } + // add this byte to string unless it was not part of the name + if i != last { + rebound = append(rebound, b) + } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { + rebound = append(rebound, b) + } + } else { + // this is a normal byte and should just go onto the rebound query + rebound = append(rebound, b) + } + } + + return string(rebound), names, err +} diff --git a/internal/named/named_test.go b/internal/named/named_test.go new file mode 100644 index 0000000000..9c0aaff715 --- /dev/null +++ b/internal/named/named_test.go @@ -0,0 +1,121 @@ +package named + +// Copyright (c) 2013, Jason Moiron +// +// Permission is hereby granted, free of charge, to any person +// obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without +// restriction, including without limitation the rights to use, +// copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +// OTHER DEALINGS IN THE SOFTWARE. + +import ( + "testing" +) + +func TestCompileQuery(t *testing.T) { + table := []struct { + Q, R, D, T, N string + V []string + }{ + // basic test for named parameters, invalid char ',' terminating + { + Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, + D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, + T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`, + N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + V: []string{"name", "age", "first", "last"}, + }, + // This query tests a named parameter ending the string as well as numbers + { + Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, + D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, + T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`, + N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + V: []string{"name1", "name2"}, + }, + { + Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, + D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, + T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`, + N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + V: []string{"name1", "name2"}, + }, + { + Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, + D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, + T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`, + N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + V: []string{"first_name", "last_name"}, + }, + { + Q: `SELECT @name := "name", :age, :first, :last`, + R: `SELECT @name := "name", ?, ?, ?`, + D: `SELECT @name := "name", $1, $2, $3`, + N: `SELECT @name := "name", :age, :first, :last`, + T: `SELECT @name := "name", @p1, @p2, @p3`, + V: []string{"age", "first", "last"}, + }, + /* This unicode awareness test sadly fails, because of our byte-wise worldview. + * We could certainly iterate by Rune instead, though it's a great deal slower, + * it's probably the RightWay(tm) + { + Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, + R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, + D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, + N: []string{"name", "age", "first", "last"}, + }, + */ + } + + for _, test := range table { + qr, names, err := CompileNamedQuery([]byte(test.Q), QUESTION) + if err != nil { + t.Error(err) + } + if qr != test.R { + t.Errorf("expected %s, got %s", test.R, qr) + } + if len(names) != len(test.V) { + t.Errorf("expected %#v, got %#v", test.V, names) + } else { + for i, name := range names { + if name != test.V[i] { + t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) + } + } + } + qd, _, _ := CompileNamedQuery([]byte(test.Q), DOLLAR) + if qd != test.D { + t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) + } + + qt, _, _ := CompileNamedQuery([]byte(test.Q), AT) + if qt != test.T { + t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) + } + + qq, _, _ := CompileNamedQuery([]byte(test.Q), NAMED) + if qq != test.N { + t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) + } + } +} diff --git a/internal/pg/catalog.go b/internal/pg/catalog.go index f4eec2d23a..e00836a4c4 100644 --- a/internal/pg/catalog.go +++ b/internal/pg/catalog.go @@ -5,6 +5,7 @@ func NewCatalog() Catalog { Schemas: map[string]Schema{ "public": NewSchema(), "pg_catalog": pgCatalog(), + "sqlc": internalSchema(), // Likewise, the current session's temporary-table schema, pg_temp_nnn, is // always searched if it exists. It can be explicitly listed in the path by // using the alias pg_temp. If it is not listed in the path then it is diff --git a/internal/pg/sqlc.go b/internal/pg/sqlc.go new file mode 100644 index 0000000000..4c6fc57460 --- /dev/null +++ b/internal/pg/sqlc.go @@ -0,0 +1,24 @@ +package pg + +func internalSchema() Schema { + s := NewSchema() + s.Name = "sqlc" + fs := []Function{ + { + Name: "arg", + Desc: "Named argumented placeholder", + ReturnType: "void", + Arguments: []Argument{ + { + Name: "name", + DataType: "id", + }, + }, + }, + } + s.Funcs = make(map[string][]Function, len(fs)) + for _, f := range fs { + s.Funcs[f.Name] = append(s.Funcs[f.Name], f) + } + return s +} From 4ec91680848d068fa57efe176ef738392c7eb3bc Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Thu, 16 Jan 2020 20:45:59 -0800 Subject: [PATCH 03/11] getting closer --- internal/dinosql/astutil.go | 176 --- internal/dinosql/checks.go | 10 +- internal/dinosql/parser.go | 21 +- internal/dinosql/rewrite.go | 6 +- internal/postgresql/ast/astutil.go | 1334 +++++++++++++++++ .../ast}/astutil_test.go | 12 +- internal/postgresql/ast/join.go | 17 + internal/{dinosql => postgresql/ast}/soup.go | 2 +- 8 files changed, 1380 insertions(+), 198 deletions(-) delete mode 100644 internal/dinosql/astutil.go create mode 100644 internal/postgresql/ast/astutil.go rename internal/{dinosql => postgresql/ast}/astutil_test.go (82%) create mode 100644 internal/postgresql/ast/join.go rename internal/{dinosql => postgresql/ast}/soup.go (99%) diff --git a/internal/dinosql/astutil.go b/internal/dinosql/astutil.go deleted file mode 100644 index 95b980086a..0000000000 --- a/internal/dinosql/astutil.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package dinosql - -import ( - "fmt" - "reflect" - - nodes "github.com/lfittl/pg_query_go/nodes" -) - -// An ApplyFunc is invoked by Apply for each node n, even if n is nil, -// before and/or after the node's children, using a Cursor describing -// the current node and providing operations on it. -// -// The return value of ApplyFunc controls the syntax tree traversal. -// See Apply for details. -type ApplyFunc func(*Cursor) bool - -// Apply traverses a syntax tree recursively, starting with root, -// and calling pre and post for each node as described below. -// Apply returns the syntax tree, possibly modified. -// -// If pre is not nil, it is called for each node before the node's -// children are traversed (pre-order). If pre returns false, no -// children are traversed, and post is not called for that node. -// -// If post is not nil, and a prior call of pre didn't return false, -// post is called for each node after its children are traversed -// (post-order). If post returns false, traversal is terminated and -// Apply returns immediately. -// -// Only fields that refer to AST nodes are considered children; -// i.e., token.Pos, Scopes, Objects, and fields of basic types -// (strings, etc.) are ignored. -// -// Children are traversed in the order in which they appear in the -// respective node's struct definition. A package's files are -// traversed in the filenames' alphabetical order. -// -func Apply(root nodes.Node, pre, post ApplyFunc) (result nodes.Node) { - parent := &struct{ nodes.Node }{root} - a := &application{pre: pre, post: post} - return a.apply(parent, "Node", nil, root) -} - -var abort = new(int) // singleton, to signal termination of Apply - -// A Cursor describes a node encountered during Apply. -// Information about the node and its parent is available -// from the Node, Parent, Name, and Index methods. -// -// If p is a variable of type and value of the current parent node -// c.Parent(), and f is the field identifier with name c.Name(), -// the following invariants hold: -// -// p.f == c.Node() if c.Index() < 0 -// p.f[c.Index()] == c.Node() if c.Index() >= 0 -// -// The methods Replace, Delete, InsertBefore, and InsertAfter -// can be used to change the AST without disrupting Apply. -type Cursor struct { - parent nodes.Node - name string - iter *iterator // valid if non-nil - node nodes.Node -} - -// Node returns the current Node. -func (c *Cursor) Node() nodes.Node { return c.node } - -// Parent returns the parent of the current Node. -func (c *Cursor) Parent() nodes.Node { return c.parent } - -// Name returns the name of the parent Node field that contains the current Node. -// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns -// the filename for the current Node. -func (c *Cursor) Name() string { return c.name } - -// Index reports the index >= 0 of the current Node in the slice of Nodes that -// contains it, or a value < 0 if the current Node is not part of a slice. -// The index of the current node changes if InsertBefore is called while -// processing the current node. -func (c *Cursor) Index() int { - if c.iter != nil { - return c.iter.index - } - return -1 -} - -// field returns the current node's parent field value. -func (c *Cursor) field() reflect.Value { - return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) -} - -// Replace replaces the current Node with n. -// The replacement node is not walked by Apply. -func (c *Cursor) Replace(n nodes.Node) { - c.node = n -} - -// application carries all the shared data so we can pass it around cheaply. -type application struct { - pre, post ApplyFunc - cursor Cursor - iter iterator -} - -func (a *application) apply(parent nodes.Node, name string, iter *iterator, n nodes.Node) nodes.Node { - // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead - cursor := Cursor{ - parent: parent, - name: name, - iter: iter, - node: n, - } - - if a.pre != nil && !a.pre(&cursor) { - return cursor.node - } - - // walk children - // (the order of the cases matches the order of the corresponding node types in go/ast) - switch n := n.(type) { - case nil: - // nothing to do - - case nodes.RawStmt: - n.Stmt = a.apply(n, "Stmt", nil, n.Stmt) - cursor.node = n - - case nodes.SelectStmt: - n.TargetList = a.apply(n, "TargetList", nil, n.TargetList) - cursor.node = n - - default: - panic(fmt.Sprintf("Apply: unexpected node type %T", n)) - } - - if a.post != nil && !a.post(&cursor) { - panic(abort) - } - - return cursor.node -} - -// An iterator controls iteration over a slice of nodes. -type iterator struct { - index, step int -} - -func (a *application) applyList(parent nodes.Node, name string) { - // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead - saved := a.iter - a.iter.index = 0 - for { - // must reload parent.name each time, since cursor modifications might change it - v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) - if a.iter.index >= v.Len() { - break - } - - // element x may be nil in a bad AST - be cautious - var x nodes.Node - if e := v.Index(a.iter.index); e.IsValid() { - x = e.Interface().(nodes.Node) - } - - a.iter.step = 1 - a.apply(parent, name, &a.iter, x) - a.iter.index += a.iter.step - } - a.iter = saved -} diff --git a/internal/dinosql/checks.go b/internal/dinosql/checks.go index cd959e46bb..15f81c4dbd 100644 --- a/internal/dinosql/checks.go +++ b/internal/dinosql/checks.go @@ -4,16 +4,18 @@ import ( "fmt" "strings" + nodes "github.com/lfittl/pg_query_go/nodes" + "github.com/kyleconroy/sqlc/internal/catalog" "github.com/kyleconroy/sqlc/internal/pg" - nodes "github.com/lfittl/pg_query_go/nodes" + "github.com/kyleconroy/sqlc/internal/postgresql/ast" ) func validateParamRef(n nodes.Node) error { var allrefs []nodes.ParamRef // Find all parameter references - Walk(VisitorFunc(func(node nodes.Node) { + ast.Walk(ast.VisitorFunc(func(node nodes.Node) { switch n := node.(type) { case nodes.ParamRef: allrefs = append(allrefs, n) @@ -41,7 +43,7 @@ type funcCallVisitor struct { err error } -func (v *funcCallVisitor) Visit(node nodes.Node) Visitor { +func (v *funcCallVisitor) Visit(node nodes.Node) ast.Visitor { if v.err != nil { return nil } @@ -91,7 +93,7 @@ func (v *funcCallVisitor) Visit(node nodes.Node) Visitor { func validateFuncCall(c *pg.Catalog, n nodes.Node) error { visitor := funcCallVisitor{catalog: c} - Walk(&visitor, n) + ast.Walk(&visitor, n) return visitor.err } diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index aade2d2fac..2ddcfd6b70 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -12,13 +12,14 @@ import ( "strings" "unicode" - "github.com/kyleconroy/sqlc/internal/catalog" - core "github.com/kyleconroy/sqlc/internal/pg" - "github.com/kyleconroy/sqlc/internal/postgres" - "github.com/davecgh/go-spew/spew" pg "github.com/lfittl/pg_query_go" nodes "github.com/lfittl/pg_query_go/nodes" + + "github.com/kyleconroy/sqlc/internal/catalog" + core "github.com/kyleconroy/sqlc/internal/pg" + "github.com/kyleconroy/sqlc/internal/postgres" + "github.com/kyleconroy/sqlc/internal/postgresql/ast" ) func keepSpew() { @@ -324,13 +325,13 @@ func pluckQuery(source string, n nodes.RawStmt) (string, error) { func rangeVars(root nodes.Node) []nodes.RangeVar { var vars []nodes.RangeVar - find := VisitorFunc(func(node nodes.Node) { + find := ast.VisitorFunc(func(node nodes.Node) { switch n := node.(type) { case nodes.RangeVar: vars = append(vars, n) } }) - Walk(find, root) + ast.Walk(find, root) return vars } @@ -1011,7 +1012,7 @@ type limitOffset struct { nodeImpl } -func (p paramSearch) Visit(node nodes.Node) Visitor { +func (p paramSearch) Visit(node nodes.Node) ast.Visitor { switch n := node.(type) { case nodes.A_Expr: @@ -1113,7 +1114,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { func findParameters(root nodes.Node) []paramRef { v := paramSearch{refs: map[int]paramRef{}} - Walk(v, root) + ast.Walk(v, root) refs := make([]paramRef, 0) for _, r := range v.refs { refs = append(refs, r) @@ -1127,7 +1128,7 @@ type nodeSearch struct { check func(nodes.Node) bool } -func (s *nodeSearch) Visit(node nodes.Node) Visitor { +func (s *nodeSearch) Visit(node nodes.Node) ast.Visitor { if s.check(node) { s.list.Items = append(s.list.Items, node) } @@ -1136,7 +1137,7 @@ func (s *nodeSearch) Visit(node nodes.Node) Visitor { func search(root nodes.Node, f func(nodes.Node) bool) nodes.List { ns := &nodeSearch{check: f} - Walk(ns, root) + ast.Walk(ns, root) return ns.list } diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go index 8dcf04a84f..092c507e58 100644 --- a/internal/dinosql/rewrite.go +++ b/internal/dinosql/rewrite.go @@ -2,13 +2,15 @@ package dinosql import ( nodes "github.com/lfittl/pg_query_go/nodes" + + "github.com/kyleconroy/sqlc/internal/postgresql/ast" ) type stringWalker struct { String string } -func (s *stringWalker) Visit(node nodes.Node) Visitor { +func (s *stringWalker) Visit(node nodes.Node) ast.Visitor { if n, ok := node.(nodes.String); ok { s.String += n.Str } @@ -17,7 +19,7 @@ func (s *stringWalker) Visit(node nodes.Node) Visitor { func flatten(root nodes.Node) string { sw := &stringWalker{} - Walk(sw, root) + ast.Walk(sw, root) return sw.String } diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go new file mode 100644 index 0000000000..526bc557f2 --- /dev/null +++ b/internal/postgresql/ast/astutil.go @@ -0,0 +1,1334 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ast + +import ( + "fmt" + "reflect" + + "github.com/davecgh/go-spew/spew" + nodes "github.com/lfittl/pg_query_go/nodes" +) + +// An ApplyFunc is invoked by Apply for each node n, even if n is nil, +// before and/or after the node's children, using a Cursor describing +// the current node and providing operations on it. +// +// The return value of ApplyFunc controls the syntax tree traversal. +// See Apply for details. +type ApplyFunc func(*Cursor) bool + +// Apply traverses a syntax tree recursively, starting with root, +// and calling pre and post for each node as described below. +// Apply returns the syntax tree, possibly modified. +// +// If pre is not nil, it is called for each node before the node's +// children are traversed (pre-order). If pre returns false, no +// children are traversed, and post is not called for that node. +// +// If post is not nil, and a prior call of pre didn't return false, +// post is called for each node after its children are traversed +// (post-order). If post returns false, traversal is terminated and +// Apply returns immediately. +// +// Only fields that refer to AST nodes are considered children; +// i.e., token.Pos, Scopes, Objects, and fields of basic types +// (strings, etc.) are ignored. +// +// Children are traversed in the order in which they appear in the +// respective node's struct definition. A package's files are +// traversed in the filenames' alphabetical order. +// +func Apply(root nodes.Node, pre, post ApplyFunc) (result nodes.Node) { + parent := &struct{ nodes.Node }{root} + defer func() { + if r := recover(); r != nil && r != abort { + panic(r) + } + result = parent.Node + }() + a := &application{pre: pre, post: post} + a.apply(parent, "Node", nil, root) + return +} + +var abort = new(int) // singleton, to signal termination of Apply + +// A Cursor describes a node encountered during Apply. +// Information about the node and its parent is available +// from the Node, Parent, Name, and Index methods. +// +// If p is a variable of type and value of the current parent node +// c.Parent(), and f is the field identifier with name c.Name(), +// the following invariants hold: +// +// p.f == c.Node() if c.Index() < 0 +// p.f[c.Index()] == c.Node() if c.Index() >= 0 +// +// The methods Replace, Delete, InsertBefore, and InsertAfter +// can be used to change the AST without disrupting Apply. +type Cursor struct { + parent nodes.Node + name string + iter *iterator // valid if non-nil + node nodes.Node +} + +// Node returns the current Node. +func (c *Cursor) Node() nodes.Node { return c.node } + +// Parent returns the parent of the current Node. +func (c *Cursor) Parent() nodes.Node { return c.parent } + +// Name returns the name of the parent Node field that contains the current Node. +// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns +// the filename for the current Node. +func (c *Cursor) Name() string { return c.name } + +// Index reports the index >= 0 of the current Node in the slice of Nodes that +// contains it, or a value < 0 if the current Node is not part of a slice. +// The index of the current node changes if InsertBefore is called while +// processing the current node. +func (c *Cursor) Index() int { + if c.iter != nil { + return c.iter.index + } + return -1 +} + +// field returns the current node's parent field value. +func (c *Cursor) field() reflect.Value { + return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) +} + +// Replace replaces the current Node with n. +// The replacement node is not walked by Apply. +func (c *Cursor) Replace(n nodes.Node) { + v := c.field() + if i := c.Index(); i >= 0 { + v = v.Index(i) + } + v.Set(reflect.ValueOf(n)) +} + +// application carries all the shared data so we can pass it around cheaply. +type application struct { + pre, post ApplyFunc + cursor Cursor + iter iterator +} + +func (a *application) apply(parent nodes.Node, name string, iter *iterator, node nodes.Node) { + // convert typed nil into untyped nil + if v := reflect.ValueOf(node); v.Kind() == reflect.Ptr && v.IsNil() { + node = nil + } + + // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead + saved := a.cursor + a.cursor.parent = parent + a.cursor.name = name + a.cursor.iter = iter + a.cursor.node = node + + if a.pre != nil && !a.pre(&a.cursor) { + a.cursor = saved + return + } + + // walk children + // (the order of the cases matches the order of the corresponding node types in go/ast) + switch n := node.(type) { + case nil: + // nothing to do + + case nodes.A_ArrayExpr: + a.apply(&n, "Elements", nil, n.Elements) + a.cursor.Replace(n) + + case nodes.A_Const: + a.apply(&n, "Val", nil, n.Val) + a.cursor.Replace(n) + + case nodes.A_Expr: + a.apply(&n, "Name", nil, n.Name) + a.apply(&n, "Lexpr", nil, n.Lexpr) + a.apply(&n, "Rexpr", nil, n.Rexpr) + a.cursor.Replace(n) + + case nodes.A_Indices: + a.apply(&n, "Lidx", nil, n.Lidx) + a.apply(&n, "Uidx", nil, n.Uidx) + + case nodes.A_Indirection: + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Indirection", nil, n.Indirection) + + case nodes.A_Star: + // pass + + case nodes.AccessPriv: + a.apply(&n, "Cols", nil, n.Cols) + + case nodes.Aggref: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Aggargtypes", nil, n.Aggargtypes) + a.apply(&n, "Aggdirectargs", nil, n.Aggdirectargs) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Aggorder", nil, n.Aggorder) + a.apply(&n, "Aggdistinct", nil, n.Aggdistinct) + a.apply(&n, "Aggfilter", nil, n.Aggfilter) + + case nodes.Alias: + a.apply(&n, "Colnames", nil, n.Colnames) + + case nodes.AlterCollationStmt: + a.apply(&n, "Collname", nil, n.Collname) + + case nodes.AlterDatabaseSetStmt: + if n.Setstmt != nil { + a.apply(&n, "Setstmt", nil, *n.Setstmt) + } + + case nodes.AlterDatabaseStmt: + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterDefaultPrivilegesStmt: + if n.Action != nil { + a.apply(&n, "Action", nil, *n.Action) + } + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterDomainStmt: + a.apply(&n, "TypeName", nil, n.TypeName) + a.apply(&n, "Def", nil, n.Def) + + case nodes.AlterEnumStmt: + a.apply(&n, "TypeName", nil, n.TypeName) + + case nodes.AlterEventTrigStmt: + // pass + + case nodes.AlterExtensionContentsStmt: + a.apply(&n, "Object", nil, n.Object) + + case nodes.AlterExtensionStmt: + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterFdwStmt: + a.apply(&n, "FuncOptions", nil, n.FuncOptions) + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterForeignServerStmt: + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterFunctionStmt: + if n.Func != nil { + a.apply(&n, "Func", nil, n.Func) + } + a.apply(&n, "Actions", nil, n.Actions) + + case nodes.AlterObjectDependsStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + a.apply(&n, "Extname", nil, n.Extname) + + case nodes.AlterObjectSchemaStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + + case nodes.AlterOpFamilyStmt: + a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) + a.apply(&n, "Items", nil, n.Items) + + case nodes.AlterOperatorStmt: + if n.Opername != nil { + a.apply(&n, "Opername", nil, *n.Opername) + } + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterOwnerStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Object", nil, n.Object) + if n.Newowner != nil { + a.apply(&n, "Newowner", nil, *n.Newowner) + } + + case nodes.AlterPolicyStmt: + if n.Table != nil { + a.apply(&n, "Table", nil, *n.Table) + } + a.apply(&n, "Roles", nil, n.Roles) + a.apply(&n, "Qual", nil, n.Qual) + a.apply(&n, "WithCheck", nil, n.WithCheck) + + case nodes.AlterPublicationStmt: + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "Table", nil, n.Tables) + + case nodes.AlterRoleSetStmt: + if n.Role != nil { + a.apply(&n, "Role", nil, *n.Role) + } + a.apply(&n, "Setstmt", nil, n.Setstmt) + + case nodes.AlterRoleStmt: + if n.Role != nil { + a.apply(&n, "Role", nil, *n.Role) + } + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterSeqStmt: + if n.Sequence != nil { + a.apply(&n, "Sequence", nil, *n.Sequence) + } + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterSubscriptionStmt: + a.apply(&n, "Publication", nil, n.Publication) + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterSystemStmt: + a.apply(&n, "Setstmt", nil, n.Setstmt) + + case nodes.AlterTSConfigurationStmt: + a.apply(&n, "Cfgname", nil, n.Cfgname) + a.apply(&n, "Tokentype", nil, n.Tokentype) + a.apply(&n, "Dicts", nil, n.Dicts) + + case nodes.AlterTSDictionaryStmt: + a.apply(&n, "Dictname", nil, n.Dictname) + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterTableCmd: + if n.Newowner != nil { + a.apply(&n, "Newowner", nil, *n.Newowner) + } + a.apply(&n, "Def", nil, n.Def) + + case nodes.AlterTableMoveAllStmt: + a.apply(&n, "Roles", nil, n.Roles) + + case nodes.AlterTableSpaceOptionsStmt: + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlterTableStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Cmds", nil, n.Cmds) + + case nodes.AlterUserMappingStmt: + if n.User != nil { + a.apply(&n, "User", nil, *n.User) + } + a.apply(&n, "Options", nil, n.Options) + + case nodes.AlternativeSubPlan: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Subplans", nil, n.Subplans) + + case nodes.ArrayCoerceExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + + case nodes.ArrayExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Elements", nil, n.Elements) + + case nodes.ArrayRef: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Refupperindexpr", nil, n.Refupperindexpr) + a.apply(&n, "Reflowerindexpr", nil, n.Reflowerindexpr) + a.apply(&n, "Refexpr", nil, n.Refexpr) + a.apply(&n, "Refassgnexpr", nil, n.Refassgnexpr) + + case nodes.BitString: + // pass + + case nodes.BlockIdData: + // pass + + case nodes.BoolExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + + case nodes.BooleanTest: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + + case nodes.CaseExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Defresult", nil, n.Defresult) + + case nodes.CaseTestExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + + case nodes.CaseWhen: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Expr", nil, n.Expr) + a.apply(&n, "Result", nil, n.Result) + + case nodes.CheckPointStmt: + // pass + + case nodes.ClosePortalStmt: + // pass + + case nodes.ClusterStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + + case nodes.CoalesceExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + + case nodes.CoerceToDomain: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + + case nodes.CoerceToDomainValue: + a.apply(&n, "Xpr", nil, n.Xpr) + + case nodes.CoerceViaIO: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + + case nodes.CollateClause: + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Collname", nil, n.Collname) + + case nodes.CollateExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + + case nodes.ColumnDef: + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + a.apply(&n, "RawDefault", nil, n.RawDefault) + a.apply(&n, "CookedDefault", nil, n.CookedDefault) + a.apply(&n, "Constraints", nil, n.Constraints) + a.apply(&n, "Fdwoptions", nil, n.Fdwoptions) + + case nodes.ColumnRef: + a.apply(&n, "Fields", nil, n.Fields) + + case nodes.CommentStmt: + a.apply(&n, "Object", nil, n.Object) + + case nodes.CommonTableExpr: + a.apply(&n, "Aliascolnames", nil, n.Aliascolnames) + a.apply(&n, "Ctequery", nil, n.Ctequery) + a.apply(&n, "Ctecolnames", nil, n.Ctecolnames) + a.apply(&n, "Ctecolcollations", nil, n.Ctecolcollations) + + case nodes.CompositeTypeStmt: + if n.Typevar != nil { + a.apply(&n, "Typevar", nil, *n.Typevar) + } + a.apply(&n, "Coldeflist", nil, n.Coldeflist) + + case nodes.Const: + a.apply(&n, "Xpr", nil, n.Xpr) + + case nodes.Constraint: + a.apply(&n, "RawExpr", nil, n.RawExpr) + a.apply(&n, "Keys", nil, n.Keys) + a.apply(&n, "Exclusions", nil, n.Exclusions) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "WhereClause", nil, n.WhereClause) + if n.Pktable != nil { + a.apply(&n, "Pktable", nil, *n.Pktable) + } + a.apply(&n, "FkAttrs", nil, n.FkAttrs) + a.apply(&n, "PkAttrs", nil, n.PkAttrs) + a.apply(&n, "OldConpfeqop", nil, n.OldConpfeqop) + + case nodes.ConstraintsSetStmt: + a.apply(&n, "Constraints", nil, n.Constraints) + + case nodes.ConvertRowtypeExpr: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + + case nodes.CopyStmt: + if n.Relation != nil { + a.apply(&n, "Relation", nil, *n.Relation) + } + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Attlist", nil, n.Attlist) + a.apply(&n, "Options", nil, n.Options) + + case nodes.CreateAmStmt: + a.apply(&n, "HandlerName", nil, n.HandlerName) + + case nodes.CreateCastStmt: + if n.Sourcetype != nil { + a.apply(&n, "Sourcetype", nil, *n.Sourcetype) + } + if n.Targettype != nil { + a.apply(&n, "Targettype", nil, *n.Targettype) + } + a.apply(&n, "Func", nil, n.Func) + + case nodes.CreateConversionStmt: + a.apply(&n, "ConversionName", nil, n.ConversionName) + a.apply(&n, "Funcname", nil, n.FuncName) + + case nodes.CreateDomainStmt: + a.apply(&n, "Domainname", nil, n.Domainname) + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + if n.CollClause != nil { + a.apply(&n, "CollClause", nil, *n.CollClause) + } + a.apply(&n, "Constraints", nil, n.Constraints) + + case nodes.CreateEnumStmt: + a.apply(&n, "", nil, n.TypeName) + a.apply(&n, "", nil, n.Vals) + + case nodes.CreateEventTrigStmt: + a.apply(&n, "", nil, n.Whenclause) + a.apply(&n, "", nil, n.Funcname) + + case nodes.CreateExtensionStmt: + a.apply(&n, "", nil, n.Options) + + case nodes.CreateFdwStmt: + a.apply(&n, "", nil, n.FuncOptions) + a.apply(&n, "", nil, n.Options) + + case nodes.CreateForeignServerStmt: + a.apply(&n, "", nil, n.Options) + + case nodes.CreateForeignTableStmt: + a.apply(&n, "", nil, n.Base) + a.apply(&n, "", nil, n.Options) + + case nodes.CreateFunctionStmt: + a.apply(&n, "", nil, n.Funcname) + a.apply(&n, "", nil, n.Parameters) + if n.ReturnType != nil { + a.apply(&n, "", nil, *n.ReturnType) + } + a.apply(&n, "", nil, n.Options) + a.apply(&n, "", nil, n.WithClause) + + case nodes.CreateOpClassItem: + a.apply(&n, "", nil, n.Name) + a.apply(&n, "", nil, n.OrderFamily) + a.apply(&n, "", nil, n.ClassArgs) + if n.Storedtype != nil { + a.apply(&n, "", nil, *n.Storedtype) + } + + case nodes.CreateOpClassStmt: + a.apply(&n, "", nil, n.Opclassname) + a.apply(&n, "", nil, n.Opfamilyname) + if n.Datatype != nil { + a.apply(&n, "", nil, *n.Datatype) + } + a.apply(&n, "", nil, n.Items) + + case nodes.CreateOpFamilyStmt: + a.apply(&n, "", nil, n.Opfamilyname) + + case nodes.CreatePLangStmt: + a.apply(&n, "", nil, n.Plhandler) + a.apply(&n, "", nil, n.Plinline) + a.apply(&n, "", nil, n.Plvalidator) + + case nodes.CreatePolicyStmt: + if n.Table != nil { + a.apply(&n, "", nil, *n.Table) + } + a.apply(&n, "", nil, n.Roles) + a.apply(&n, "", nil, n.Qual) + a.apply(&n, "", nil, n.WithCheck) + + case nodes.CreatePublicationStmt: + a.apply(&n, "", nil, n.Options) + a.apply(&n, "", nil, n.Tables) + + case nodes.CreateRangeStmt: + a.apply(&n, "", nil, n.TypeName) + a.apply(&n, "", nil, n.Params) + + case nodes.CreateRoleStmt: + a.apply(&n, "", nil, n.Options) + + case nodes.CreateSchemaStmt: + if n.Authrole != nil { + a.apply(&n, "", nil, *n.Authrole) + } + a.apply(&n, "", nil, n.SchemaElts) + + case nodes.CreateSeqStmt: + if n.Sequence != nil { + a.apply(&n, "", nil, *n.Sequence) + } + a.apply(&n, "", nil, n.Options) + + case nodes.CreateStatsStmt: + a.apply(&n, "", nil, n.Defnames) + a.apply(&n, "", nil, n.StatTypes) + a.apply(&n, "", nil, n.Exprs) + a.apply(&n, "", nil, n.Relations) + + case nodes.CreateStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.TableElts) + a.apply(&n, "", nil, n.InhRelations) + if n.Partbound != nil { + a.apply(&n, "", nil, *n.Partbound) + } + if n.Partspec != nil { + a.apply(&n, "", nil, *n.Partspec) + } + a.apply(&n, "", nil, n.Constraints) + a.apply(&n, "", nil, n.Options) + if n.OfTypename != nil { + a.apply(&n, "", nil, *n.OfTypename) + } + + case nodes.CreateSubscriptionStmt: + a.apply(&n, "", nil, n.Publication) + a.apply(&n, "", nil, n.Options) + + case nodes.CreateTableAsStmt: + a.apply(&n, "", nil, n.Query) + a.apply(&n, "", nil, n.Into) + + case nodes.CreateTableSpaceStmt: + if n.Owner != nil { + a.apply(&n, "", nil, *n.Owner) + } + a.apply(&n, "", nil, n.Options) + + case nodes.CreateTransformStmt: + if n.TypeName != nil { + a.apply(&n, "", nil, *n.TypeName) + } + if n.Fromsql != nil { + a.apply(&n, "", nil, *n.Fromsql) + } + if n.Tosql != nil { + a.apply(&n, "", nil, *n.Tosql) + } + + case nodes.CreateTrigStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.Funcname) + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.Columns) + a.apply(&n, "", nil, n.WhenClause) + a.apply(&n, "", nil, n.TransitionRels) + if n.Constrrel != nil { + a.apply(&n, "", nil, *n.Constrrel) + } + + case nodes.CreateUserMappingStmt: + if n.User != nil { + a.apply(&n, "", nil, *n.User) + } + a.apply(&n, "", nil, n.Options) + + case nodes.CreatedbStmt: + a.apply(&n, "", nil, n.Options) + + case nodes.CurrentOfExpr: + a.apply(&n, "", nil, n.Xpr) + + case nodes.DeallocateStmt: + // pass + + case nodes.DeclareCursorStmt: + a.apply(&n, "", nil, n.Query) + + case nodes.DefElem: + a.apply(&n, "", nil, n.Arg) + + case nodes.DefineStmt: + a.apply(&n, "", nil, n.Defnames) + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.Definition) + + case nodes.DeleteStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.UsingClause) + a.apply(&n, "", nil, n.WhereClause) + a.apply(&n, "", nil, n.ReturningList) + if n.WithClause != nil { + a.apply(&n, "", nil, *n.WithClause) + } + + case nodes.DiscardStmt: + // pass + + case nodes.DoStmt: + a.apply(&n, "", nil, n.Args) + + case nodes.DropOwnedStmt: + a.apply(&n, "", nil, n.Roles) + + case nodes.DropRoleStmt: + a.apply(&n, "", nil, n.Roles) + + case nodes.DropStmt: + a.apply(&n, "", nil, n.Objects) + + case nodes.DropSubscriptionStmt: + // pass + + case nodes.DropTableSpaceStmt: + // pass + + case nodes.DropUserMappingStmt: + if n.User != nil { + a.apply(&n, "", nil, *n.User) + } + + case nodes.DropdbStmt: + // pass + + case nodes.ExecuteStmt: + a.apply(&n, "", nil, n.Params) + + case nodes.ExplainStmt: + a.apply(&n, "", nil, n.Query) + a.apply(&n, "", nil, n.Options) + + case nodes.Expr: + // pass + + case nodes.FetchStmt: + // pass + + case nodes.FieldSelect: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Arg) + + case nodes.FieldStore: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Arg) + a.apply(&n, "", nil, n.Newvals) + a.apply(&n, "", nil, n.Fieldnums) + + case nodes.Float: + // pass + + case nodes.FromExpr: + a.apply(&n, "", nil, n.Fromlist) + a.apply(&n, "", nil, n.Quals) + + case nodes.FuncCall: + a.apply(&n, "", nil, n.Funcname) + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.AggOrder) + a.apply(&n, "", nil, n.AggFilter) + if n.Over != nil { + a.apply(&n, "", nil, *n.Over) + } + + case nodes.FuncExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + + case nodes.FunctionParameter: + if n.ArgType != nil { + a.apply(&n, "", nil, *n.ArgType) + } + a.apply(&n, "", nil, n.Defexpr) + + case nodes.GrantRoleStmt: + a.apply(&n, "", nil, n.GrantedRoles) + a.apply(&n, "", nil, n.GranteeRoles) + if n.Grantor != nil { + a.apply(&n, "", nil, *n.Grantor) + } + + case nodes.GrantStmt: + a.apply(&n, "", nil, n.Objects) + a.apply(&n, "", nil, n.Privileges) + a.apply(&n, "", nil, n.Grantees) + + case nodes.GroupingFunc: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.Refs) + a.apply(&n, "", nil, n.Cols) + + case nodes.GroupingSet: + a.apply(&n, "", nil, n.Content) + + case nodes.ImportForeignSchemaStmt: + a.apply(&n, "", nil, n.TableList) + a.apply(&n, "", nil, n.Options) + + case nodes.IndexElem: + a.apply(&n, "", nil, n.Expr) + a.apply(&n, "", nil, n.Collation) + a.apply(&n, "", nil, n.Opclass) + + case nodes.IndexStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.IndexParams) + a.apply(&n, "", nil, n.Options) + a.apply(&n, "", nil, n.WhereClause) + a.apply(&n, "", nil, n.ExcludeOpNames) + + case nodes.InferClause: + a.apply(&n, "", nil, n.IndexElems) + a.apply(&n, "", nil, n.WhereClause) + + case nodes.InferenceElem: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Expr) + + case nodes.InlineCodeBlock: + // pass + + case nodes.InsertStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.Cols) + a.apply(&n, "", nil, n.SelectStmt) + if n.OnConflictClause != nil { + a.apply(&n, "", nil, *n.OnConflictClause) + } + a.apply(&n, "", nil, n.ReturningList) + if n.WithClause != nil { + a.apply(&n, "", nil, *n.WithClause) + } + + case nodes.Integer: + // pass + + case nodes.IntoClause: + if n.Rel != nil { + a.apply(&n, "", nil, *n.Rel) + } + a.apply(&n, "", nil, n.ColNames) + a.apply(&n, "", nil, n.Options) + a.apply(&n, "", nil, n.ViewQuery) + + case nodes.JoinExpr: + a.apply(&n, "", nil, n.Larg) + a.apply(&n, "", nil, n.Rarg) + a.apply(&n, "", nil, n.UsingClause) + a.apply(&n, "", nil, n.Quals) + if n.Alias != nil { + a.apply(&n, "", nil, *n.Alias) + } + + case nodes.List: + a.applyList(&n, "Items") + spew.Dump(a.cursor) + a.cursor.Replace(n) + + case nodes.ListenStmt: + // pass + + case nodes.LoadStmt: + // pass + + case nodes.LockStmt: + a.apply(&n, "", nil, n.Relations) + + case nodes.LockingClause: + a.apply(&n, "", nil, n.LockedRels) + + case nodes.MinMaxExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + + case nodes.MultiAssignRef: + a.apply(&n, "", nil, n.Source) + + case nodes.NamedArgExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Arg) + + case nodes.NextValueExpr: + a.apply(&n, "", nil, n.Xpr) + + case nodes.NotifyStmt: + // pass + + case nodes.Null: + // pass + + case nodes.NullTest: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Arg) + + case nodes.ObjectWithArgs: + a.apply(&n, "", nil, n.Objname) + a.apply(&n, "", nil, n.Objargs) + + case nodes.OnConflictClause: + if n.Infer != nil { + a.apply(&n, "", nil, *n.Infer) + } + a.apply(&n, "", nil, n.TargetList) + a.apply(&n, "", nil, n.WhereClause) + + case nodes.OnConflictExpr: + a.apply(&n, "", nil, n.ArbiterElems) + a.apply(&n, "", nil, n.ArbiterWhere) + a.apply(&n, "", nil, n.OnConflictSet) + a.apply(&n, "", nil, n.OnConflictWhere) + a.apply(&n, "", nil, n.ExclRelTlist) + + case nodes.OpExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + + case nodes.Param: + a.apply(&n, "", nil, n.Xpr) + + case nodes.ParamExecData: + // pass + + case nodes.ParamExternData: + // pass + + case nodes.ParamListInfoData: + // pass + + case nodes.ParamRef: + // pass + + case nodes.PartitionBoundSpec: + a.apply(&n, "", nil, n.Listdatums) + a.apply(&n, "", nil, n.Lowerdatums) + a.apply(&n, "", nil, n.Upperdatums) + + case nodes.PartitionCmd: + if n.Name != nil { + a.apply(&n, "", nil, *n.Name) + } + if n.Bound != nil { + a.apply(&n, "", nil, *n.Bound) + } + + case nodes.PartitionElem: + a.apply(&n, "", nil, n.Expr) + a.apply(&n, "", nil, n.Collation) + a.apply(&n, "", nil, n.Opclass) + + case nodes.PartitionRangeDatum: + a.apply(&n, "", nil, n.Value) + + case nodes.PartitionSpec: + a.apply(&n, "", nil, n.PartParams) + + case nodes.PrepareStmt: + a.apply(&n, "", nil, n.Argtypes) + a.apply(&n, "", nil, n.Query) + + case nodes.Query: + a.apply(&n, "", nil, n.UtilityStmt) + a.apply(&n, "", nil, n.CteList) + a.apply(&n, "", nil, n.Jointree) + a.apply(&n, "", nil, n.TargetList) + a.apply(&n, "", nil, n.OnConflict) + a.apply(&n, "", nil, n.ReturningList) + a.apply(&n, "", nil, n.GroupClause) + a.apply(&n, "", nil, n.GroupingSets) + a.apply(&n, "", nil, n.HavingQual) + a.apply(&n, "", nil, n.WindowClause) + a.apply(&n, "", nil, n.DistinctClause) + a.apply(&n, "", nil, n.SortClause) + a.apply(&n, "", nil, n.LimitCount) + a.apply(&n, "", nil, n.RowMarks) + a.apply(&n, "", nil, n.SetOperations) + a.apply(&n, "", nil, n.ConstraintDeps) + a.apply(&n, "", nil, n.WithCheckOptions) + + case nodes.RangeFunction: + a.apply(&n, "", nil, n.Functions) + if n.Alias != nil { + a.apply(&n, "", nil, *n.Alias) + } + a.apply(&n, "", nil, n.Coldeflist) + + case nodes.RangeSubselect: + a.apply(&n, "", nil, n.Subquery) + if n.Alias != nil { + a.apply(&n, "", nil, *n.Alias) + } + + case nodes.RangeTableFunc: + a.apply(&n, "", nil, n.Docexpr) + a.apply(&n, "", nil, n.Rowexpr) + a.apply(&n, "", nil, n.Namespaces) + a.apply(&n, "", nil, n.Columns) + if n.Alias != nil { + a.apply(&n, "", nil, *n.Alias) + } + + case nodes.RangeTableFuncCol: + if n.TypeName != nil { + a.apply(&n, "", nil, *n.TypeName) + } + a.apply(&n, "", nil, n.Colexpr) + a.apply(&n, "", nil, n.Coldefexpr) + + case nodes.RangeTableSample: + a.apply(&n, "", nil, n.Relation) + a.apply(&n, "", nil, n.Method) + a.apply(&n, "", nil, n.Args) + + case nodes.RangeTblEntry: + a.apply(&n, "", nil, n.Tablesample) + a.apply(&n, "", nil, n.Subquery) + a.apply(&n, "", nil, n.Joinaliasvars) + a.apply(&n, "", nil, n.Functions) + a.apply(&n, "", nil, n.Tablefunc) + a.apply(&n, "", nil, n.ValuesLists) + a.apply(&n, "", nil, n.Coltypes) + a.apply(&n, "", nil, n.Colcollations) + if n.Alias != nil { + a.apply(&n, "", nil, *n.Alias) + } + a.apply(&n, "", nil, n.Eref) + a.apply(&n, "", nil, n.SecurityQuals) + + case nodes.RangeTblFunction: + a.apply(&n, "", nil, n.Funcexpr) + a.apply(&n, "", nil, n.Funccolnames) + a.apply(&n, "", nil, n.Funccoltypes) + a.apply(&n, "", nil, n.Funccoltypmods) + a.apply(&n, "", nil, n.Funccolcollations) + + case nodes.RangeTblRef: + // pass + + case nodes.RangeVar: + if n.Alias != nil { + a.apply(&n, "", nil, *n.Alias) + } + + case nodes.RawStmt: + a.apply(&n, "Stmt", nil, n.Stmt) + a.cursor.Replace(n) + + case nodes.ReassignOwnedStmt: + a.apply(&n, "", nil, n.Roles) + if n.Newrole != nil { + a.apply(&n, "", nil, *n.Newrole) + } + + case nodes.RefreshMatViewStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + + case nodes.ReindexStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + + case nodes.RelabelType: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Arg) + + case nodes.RenameStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.Object) + + case nodes.ReplicaIdentityStmt: + // pass + + case nodes.ResTarget: + a.apply(&n, "Indirection", nil, n.Indirection) + a.apply(&n, "Val", nil, n.Val) + a.cursor.Replace(n) + + case nodes.RoleSpec: + // pass + + case nodes.RowCompareExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Opnos) + a.apply(&n, "", nil, n.Opfamilies) + a.apply(&n, "", nil, n.Inputcollids) + a.apply(&n, "", nil, n.Largs) + a.apply(&n, "", nil, n.Rargs) + + case nodes.RowExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.Colnames) + + case nodes.RowMarkClause: + // pass + + case nodes.RuleStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.WhereClause) + a.apply(&n, "", nil, n.Actions) + + case nodes.SQLValueFunction: + a.apply(&n, "", nil, n.Xpr) + + case nodes.ScalarArrayOpExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + + case nodes.SecLabelStmt: + a.apply(&n, "", nil, n.Object) + + case nodes.SelectStmt: + a.apply(&n, "DistinctClause", nil, n.DistinctClause) + if n.IntoClause != nil { + a.apply(&n, "IntoClause", nil, *n.IntoClause) + } + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "FromClause", nil, n.FromClause) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "GroupClause", nil, n.GroupClause) + a.apply(&n, "HavingClause", nil, n.HavingClause) + a.apply(&n, "WindowClause", nil, n.WindowClause) + // TODO: Not sure how to handle a slice of a slice + // + // for _, vs := range n.ValuesLists { + // for _, v := range vs { + // a.apply(&n, "", nil, v) + // } + // } + a.apply(&n, "SortClause", nil, n.SortClause) + a.apply(&n, "LimitOffset", nil, n.LimitOffset) + a.apply(&n, "LimitCount", nil, n.LimitCount) + a.apply(&n, "LockingClause", nil, n.LockingClause) + if n.WithClause != nil { + a.apply(&n, "WithClause", nil, *n.WithClause) + } + if n.Larg != nil { + a.apply(&n, "Larg", nil, *n.Larg) + } + if n.Rarg != nil { + a.apply(&n, "Rarg", nil, *n.Rarg) + } + a.cursor.Replace(n) + + case nodes.SetOperationStmt: + a.apply(&n, "Larg", nil, n.Larg) + a.apply(&n, "Rarg", nil, n.Rarg) + a.apply(&n, "ColTypes", nil, n.ColTypes) + a.apply(&n, "ColTypmods", nil, n.ColTypmods) + a.apply(&n, "ColCollations", nil, n.ColCollations) + a.apply(&n, "GroupClauses", nil, n.GroupClauses) + + case nodes.SetToDefault: + a.apply(&n, "Xpr", nil, n.Xpr) + + case nodes.SortBy: + a.apply(&n, "Node", nil, n.Node) + a.apply(&n, "UseOp", nil, n.UseOp) + + case nodes.SortGroupClause: + // pass + + case nodes.String: + // pass + + case nodes.SubLink: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Testexpr", nil, n.Testexpr) + a.apply(&n, "Opername", nil, n.OperName) + a.apply(&n, "Subselect", nil, n.Subselect) + + case nodes.SubPlan: + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Testexpr", nil, n.Testexpr) + a.apply(&n, "ParamIds", nil, n.ParamIds) + a.apply(&n, "SetParam", nil, n.SetParam) + a.apply(&n, "ParParam", nil, n.ParParam) + a.apply(&n, "Args", nil, n.Args) + + case nodes.TableFunc: + a.apply(&n, "NsUris", nil, n.NsUris) + a.apply(&n, "NsNames", nil, n.NsNames) + a.apply(&n, "Docexpr", nil, n.Docexpr) + a.apply(&n, "Rowexpr", nil, n.Rowexpr) + a.apply(&n, "Colnames", nil, n.Colnames) + a.apply(&n, "Coltypes", nil, n.Coltypes) + a.apply(&n, "ColTypmods", nil, n.Coltypmods) + a.apply(&n, "Colcollations", nil, n.Colcollations) + a.apply(&n, "Colexprs", nil, n.Colexprs) + a.apply(&n, "Coldefexprs", nil, n.Coldefexprs) + + case nodes.TableLikeClause: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + + case nodes.TableSampleClause: + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.Repeatable) + + case nodes.TargetEntry: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Expr) + + case nodes.TransactionStmt: + a.apply(&n, "", nil, n.Options) + + case nodes.TriggerTransition: + // pass + + case nodes.TruncateStmt: + a.apply(&n, "", nil, n.Relations) + + case nodes.TypeCast: + a.apply(&n, "", nil, n.Arg) + if n.TypeName != nil { + a.apply(&n, "", nil, *n.TypeName) + } + + case nodes.TypeName: + a.apply(&n, "", nil, n.Names) + a.apply(&n, "", nil, n.Typmods) + a.apply(&n, "", nil, n.ArrayBounds) + + case nodes.UnlistenStmt: + // pass + + case nodes.UpdateStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.TargetList) + a.apply(&n, "", nil, n.WhereClause) + a.apply(&n, "", nil, n.FromClause) + a.apply(&n, "", nil, n.ReturningList) + if n.WithClause != nil { + a.apply(&n, "", nil, *n.WithClause) + } + + case nodes.VacuumStmt: + if n.Relation != nil { + a.apply(&n, "", nil, *n.Relation) + } + a.apply(&n, "", nil, n.VaCols) + + case nodes.Var: + a.apply(&n, "", nil, n.Xpr) + + case nodes.VariableSetStmt: + a.apply(&n, "", nil, n.Args) + + case nodes.VariableShowStmt: + // pass + + case nodes.ViewStmt: + if n.View != nil { + a.apply(&n, "", nil, *n.View) + } + a.apply(&n, "", nil, n.Aliases) + a.apply(&n, "", nil, n.Query) + a.apply(&n, "", nil, n.Options) + + case nodes.WindowClause: + a.apply(&n, "", nil, n.PartitionClause) + a.apply(&n, "", nil, n.OrderClause) + a.apply(&n, "", nil, n.StartOffset) + a.apply(&n, "", nil, n.EndOffset) + + case nodes.WindowDef: + a.apply(&n, "", nil, n.PartitionClause) + a.apply(&n, "", nil, n.OrderClause) + a.apply(&n, "", nil, n.StartOffset) + a.apply(&n, "", nil, n.EndOffset) + + case nodes.WindowFunc: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.Args) + a.apply(&n, "", nil, n.Aggfilter) + + case nodes.WithCheckOption: + a.apply(&n, "", nil, n.Qual) + + case nodes.WithClause: + a.apply(&n, "", nil, n.Ctes) + + case nodes.XmlExpr: + a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "", nil, n.NamedArgs) + a.apply(&n, "", nil, n.ArgNames) + a.apply(&n, "", nil, n.Args) + + case nodes.XmlSerialize: + a.apply(&n, "", nil, n.Expr) + if n.TypeName != nil { + a.apply(&n, "", nil, *n.TypeName) + } + + default: + panic(fmt.Sprintf("Apply: unexpected node type %T", n)) + } + + if a.post != nil && !a.post(&a.cursor) { + panic(abort) + } + + a.cursor = saved +} + +// An iterator controls iteration over a slice of nodes. +type iterator struct { + index, step int +} + +func (a *application) applyList(parent nodes.Node, name string) { + // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead + saved := a.iter + a.iter.index = 0 + for { + // must reload parent.name each time, since cursor modifications might change it + v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) + if a.iter.index >= v.Len() { + break + } + + // element x may be nil in a bad AST - be cautious + var x nodes.Node + if e := v.Index(a.iter.index); e.IsValid() { + x = e.Interface().(nodes.Node) + } + + a.iter.step = 1 + a.apply(parent, name, &a.iter, x) + a.iter.index += a.iter.step + } + a.iter = saved +} diff --git a/internal/dinosql/astutil_test.go b/internal/postgresql/ast/astutil_test.go similarity index 82% rename from internal/dinosql/astutil_test.go rename to internal/postgresql/ast/astutil_test.go index 5e04807b2e..2f6f4273e3 100644 --- a/internal/dinosql/astutil_test.go +++ b/internal/postgresql/ast/astutil_test.go @@ -1,4 +1,4 @@ -package dinosql +package ast import ( "testing" @@ -18,18 +18,20 @@ func TestApply(t *testing.T) { t.Fatal(err) } - // spew.Dump(input.Statements[0]) - expect := output.Statements[0] actual := Apply(input.Statements[0], func(cr *Cursor) bool { fun, ok := cr.Node().(nodes.FuncCall) if !ok { return true } - if join(fun.Funcname, ".") == "sqlc.arg" { - cr.Replace(nodes.ParamRef{Number: 1}) + if Join(fun.Funcname, ".") == "sqlc.arg" { + cr.Replace(nodes.ParamRef{ + Number: 1, + Location: fun.Location, + }) return false } + return true }, nil) diff --git a/internal/postgresql/ast/join.go b/internal/postgresql/ast/join.go new file mode 100644 index 0000000000..343b58129c --- /dev/null +++ b/internal/postgresql/ast/join.go @@ -0,0 +1,17 @@ +package ast + +import ( + "strings" + + nodes "github.com/lfittl/pg_query_go/nodes" +) + +func Join(list nodes.List, sep string) string { + items := []string{} + for _, item := range list.Items { + if n, ok := item.(nodes.String); ok { + items = append(items, n.Str) + } + } + return strings.Join(items, sep) +} diff --git a/internal/dinosql/soup.go b/internal/postgresql/ast/soup.go similarity index 99% rename from internal/dinosql/soup.go rename to internal/postgresql/ast/soup.go index 17786526eb..a662e6f9a1 100644 --- a/internal/dinosql/soup.go +++ b/internal/postgresql/ast/soup.go @@ -1,4 +1,4 @@ -package dinosql +package ast import ( "fmt" From 1403d19ade187d43136e829f1584e49a290293bb Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Thu, 16 Jan 2020 21:42:27 -0800 Subject: [PATCH 04/11] Horrible work in progress --- internal/dinosql/parser.go | 57 ++++-- internal/dinosql/query_test.go | 5 +- internal/dinosql/rewrite.go | 63 +++++-- internal/postgresql/ast/astutil.go | 290 +++++++++++++++++++---------- 4 files changed, 283 insertions(+), 132 deletions(-) diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 2ddcfd6b70..cc57e997a9 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -460,15 +460,12 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) return nil, err } - // Query manipulation - raw, err = rewriteNamedParameters(raw) - if err != nil { - return nil, err - } + // Re-write query AST + raw, namedParams, edits := rewriteNamedParameters(raw) rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) - params, err := resolveCatalogRefs(c, rvs, refs) + params, err := resolveCatalogRefs(c, rvs, refs, namedParams) if err != nil { return nil, err } @@ -477,11 +474,24 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - expanded, err := expand(c, raw, rawSQL) + expandEdits, err := expand(c, raw) + if err != nil { + return nil, err + } + + edits = append(edits, expandEdits...) + expanded, err := editQuery(rawSQL, edits) if err != nil { return nil, err } + // If the query string was edited, make sure the syntax is valid + if expanded != rawSQL { + if _, err := pg.Parse(expanded); err != nil { + return nil, fmt.Errorf("edited query syntax is invalid: %w", err) + } + } + trimmed, comments, err := stripComments(strings.TrimSpace(expanded)) if err != nil { return nil, err @@ -519,7 +529,7 @@ type edit struct { New string } -func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { +func expand(c core.Catalog, raw nodes.RawStmt) ([]edit, error) { list := search(raw, func(node nodes.Node) bool { switch node.(type) { case nodes.DeleteStmt: @@ -532,17 +542,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { return true }) if len(list.Items) == 0 { - return sql, nil + return nil, nil } var edits []edit for _, item := range list.Items { edit, err := expandStmt(c, raw, item) if err != nil { - return "", err + return nil, err } edits = append(edits, edit...) } - return editQuery(sql, edits) + return edits, nil } func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { @@ -1141,12 +1151,19 @@ func search(root nodes.Node, f func(nodes.Node) bool) nodes.List { return ns.list } -func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ([]Parameter, error) { +func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) { aliasMap := map[string]core.FQN{} // TODO: Deprecate defaultTable var defaultTable *core.FQN var tables []core.FQN + parameterName := func(n int, defaultName string) string { + if n, ok := names[n]; ok { + return n + } + return defaultName + } + for _, rv := range rvs { if rv.Relname == nil { continue @@ -1196,7 +1213,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: "offset", + Name: parameterName(ref.ref.Number, "offset"), DataType: "integer", NotNull: true, }, @@ -1206,7 +1223,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: "limit", + Name: parameterName(ref.ref.Number, "limit"), DataType: "integer", NotNull: true, }, @@ -1265,7 +1282,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: key, + Name: parameterName(ref.ref.Number, key), DataType: c.DataType, NotNull: c.NotNull, IsArray: c.IsArray, @@ -1318,7 +1335,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: fun.Name, + Name: parameterName(ref.ref.Number, fun.Name), DataType: "any", }, }) @@ -1335,7 +1352,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: name, + Name: parameterName(ref.ref.Number, name), DataType: arg.DataType, NotNull: true, }, @@ -1351,7 +1368,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( a = append(a, Parameter{ Number: ref.ref.Number, Column: core.Column{ - Name: key, + Name: parameterName(ref.ref.Number, key), DataType: c.DataType, NotNull: c.NotNull, IsArray: c.IsArray, @@ -1370,9 +1387,11 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( if n.TypeName == nil { return nil, fmt.Errorf("nodes.TypeCast has nil type name") } + col := catalog.ToColumn(n.TypeName) + col.Name = parameterName(ref.ref.Number, col.Name) a = append(a, Parameter{ Number: ref.ref.Number, - Column: catalog.ToColumn(n.TypeName), + Column: col, }) case nodes.ParamRef: diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index 249c71c013..b763dd4c0c 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -888,15 +888,16 @@ func TestQueries(t *testing.T) { "named_parameter", ` CREATE TABLE foo (name text not null); - SELECT name FROM foo WHERE name = sqlc.arg(slug); + SELECT name FROM foo WHERE name = sqlc.arg(slug) AND sqlc.arg(filter)::bool; `, Query{ - SQL: "SELECT name FROM foo WHERE name = $1;", + SQL: "SELECT name FROM foo WHERE name = $1 AND $2::bool", Columns: []core.Column{ {Table: public("foo"), Name: "name", DataType: "text", NotNull: true}, }, Params: []Parameter{ {1, core.Column{Table: public("foo"), Name: "slug", DataType: "text", NotNull: true}}, + {2, core.Column{Name: "filter", DataType: "bool", NotNull: true}}, }, }, }, diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go index 092c507e58..4c88bfa426 100644 --- a/internal/dinosql/rewrite.go +++ b/internal/dinosql/rewrite.go @@ -1,11 +1,20 @@ package dinosql import ( + "fmt" + nodes "github.com/lfittl/pg_query_go/nodes" "github.com/kyleconroy/sqlc/internal/postgresql/ast" ) +// Given an AST node, return the string representation of names +func flatten(root nodes.Node) string { + sw := &stringWalker{} + ast.Walk(sw, root) + return sw.String +} + type stringWalker struct { String string } @@ -17,20 +26,46 @@ func (s *stringWalker) Visit(node nodes.Node) ast.Visitor { return s } -func flatten(root nodes.Node) string { - sw := &stringWalker{} - ast.Walk(sw, root) - return sw.String -} +func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) { + args := map[string]int{} + argn := 0 + var edits []edit + node := ast.Apply(raw, func(cr *ast.Cursor) bool { + fun, ok := cr.Node().(nodes.FuncCall) + if !ok { + return true + } + if ast.Join(fun.Funcname, ".") == "sqlc.arg" { + param := flatten(fun.Args) + if num, ok := args[param]; ok { + cr.Replace(nodes.ParamRef{ + Number: num, + Location: fun.Location, + }) + } else { + argn += 1 + args[param] = argn + cr.Replace(nodes.ParamRef{ + Number: argn, + Location: fun.Location, + }) + } + + // TODO: This code assumes that sqlc.arg(name) is on a single line + edits = append(edits, edit{ + Location: fun.Location - raw.StmtLocation, + Old: fmt.Sprintf("sqlc.arg(%s)", param), + New: fmt.Sprintf("$%d", args[param]), + }) + + return false + } + return true + }, nil) -func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, error) { - named := search(raw, func(node nodes.Node) bool { - fun, ok := node.(nodes.FuncCall) - return ok && join(fun.Funcname, ".") == "sqlc.arg" - }) - for _, np := range named.Items { - fun := np.(nodes.FuncCall) - flatten(fun.Args) + named := map[int]string{} + for k, v := range args { + named[v] = k } - return raw, nil + return node.(nodes.RawStmt), named, edits } diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go index 526bc557f2..9e19d4226a 100644 --- a/internal/postgresql/ast/astutil.go +++ b/internal/postgresql/ast/astutil.go @@ -8,7 +8,6 @@ import ( "fmt" "reflect" - "github.com/davecgh/go-spew/spew" nodes "github.com/lfittl/pg_query_go/nodes" ) @@ -161,16 +160,19 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.A_Indices: a.apply(&n, "Lidx", nil, n.Lidx) a.apply(&n, "Uidx", nil, n.Uidx) + a.cursor.Replace(n) case nodes.A_Indirection: a.apply(&n, "Arg", nil, n.Arg) a.apply(&n, "Indirection", nil, n.Indirection) + a.cursor.Replace(n) case nodes.A_Star: // pass case nodes.AccessPriv: a.apply(&n, "Cols", nil, n.Cols) + a.cursor.Replace(n) case nodes.Aggref: a.apply(&n, "Xpr", nil, n.Xpr) @@ -180,55 +182,68 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Aggorder", nil, n.Aggorder) a.apply(&n, "Aggdistinct", nil, n.Aggdistinct) a.apply(&n, "Aggfilter", nil, n.Aggfilter) + a.cursor.Replace(n) case nodes.Alias: a.apply(&n, "Colnames", nil, n.Colnames) + a.cursor.Replace(n) case nodes.AlterCollationStmt: a.apply(&n, "Collname", nil, n.Collname) + a.cursor.Replace(n) case nodes.AlterDatabaseSetStmt: if n.Setstmt != nil { a.apply(&n, "Setstmt", nil, *n.Setstmt) + a.cursor.Replace(n) } case nodes.AlterDatabaseStmt: a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterDefaultPrivilegesStmt: if n.Action != nil { a.apply(&n, "Action", nil, *n.Action) } a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterDomainStmt: a.apply(&n, "TypeName", nil, n.TypeName) a.apply(&n, "Def", nil, n.Def) + a.cursor.Replace(n) case nodes.AlterEnumStmt: a.apply(&n, "TypeName", nil, n.TypeName) + a.cursor.Replace(n) case nodes.AlterEventTrigStmt: // pass case nodes.AlterExtensionContentsStmt: a.apply(&n, "Object", nil, n.Object) + a.cursor.Replace(n) case nodes.AlterExtensionStmt: a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterFdwStmt: a.apply(&n, "FuncOptions", nil, n.FuncOptions) a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterForeignServerStmt: a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterFunctionStmt: if n.Func != nil { a.apply(&n, "Func", nil, n.Func) } a.apply(&n, "Actions", nil, n.Actions) + a.cursor.Replace(n) case nodes.AlterObjectDependsStmt: if n.Relation != nil { @@ -236,22 +251,26 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node } a.apply(&n, "Object", nil, n.Object) a.apply(&n, "Extname", nil, n.Extname) + a.cursor.Replace(n) case nodes.AlterObjectSchemaStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) } a.apply(&n, "Object", nil, n.Object) + a.cursor.Replace(n) case nodes.AlterOpFamilyStmt: a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) a.apply(&n, "Items", nil, n.Items) + a.cursor.Replace(n) case nodes.AlterOperatorStmt: if n.Opername != nil { a.apply(&n, "Opername", nil, *n.Opername) } a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterOwnerStmt: if n.Relation != nil { @@ -261,6 +280,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node if n.Newowner != nil { a.apply(&n, "Newowner", nil, *n.Newowner) } + a.cursor.Replace(n) case nodes.AlterPolicyStmt: if n.Table != nil { @@ -269,80 +289,97 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Roles", nil, n.Roles) a.apply(&n, "Qual", nil, n.Qual) a.apply(&n, "WithCheck", nil, n.WithCheck) + a.cursor.Replace(n) case nodes.AlterPublicationStmt: a.apply(&n, "Options", nil, n.Options) a.apply(&n, "Table", nil, n.Tables) + a.cursor.Replace(n) case nodes.AlterRoleSetStmt: if n.Role != nil { a.apply(&n, "Role", nil, *n.Role) } a.apply(&n, "Setstmt", nil, n.Setstmt) + a.cursor.Replace(n) case nodes.AlterRoleStmt: if n.Role != nil { a.apply(&n, "Role", nil, *n.Role) } a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterSeqStmt: if n.Sequence != nil { a.apply(&n, "Sequence", nil, *n.Sequence) } a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterSubscriptionStmt: a.apply(&n, "Publication", nil, n.Publication) a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterSystemStmt: a.apply(&n, "Setstmt", nil, n.Setstmt) + a.cursor.Replace(n) case nodes.AlterTSConfigurationStmt: a.apply(&n, "Cfgname", nil, n.Cfgname) a.apply(&n, "Tokentype", nil, n.Tokentype) a.apply(&n, "Dicts", nil, n.Dicts) + a.cursor.Replace(n) case nodes.AlterTSDictionaryStmt: a.apply(&n, "Dictname", nil, n.Dictname) a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterTableCmd: if n.Newowner != nil { a.apply(&n, "Newowner", nil, *n.Newowner) } a.apply(&n, "Def", nil, n.Def) + a.cursor.Replace(n) case nodes.AlterTableMoveAllStmt: a.apply(&n, "Roles", nil, n.Roles) + a.cursor.Replace(n) case nodes.AlterTableSpaceOptionsStmt: a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlterTableStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) } a.apply(&n, "Cmds", nil, n.Cmds) + a.cursor.Replace(n) case nodes.AlterUserMappingStmt: if n.User != nil { a.apply(&n, "User", nil, *n.User) } a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.AlternativeSubPlan: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Subplans", nil, n.Subplans) + a.cursor.Replace(n) case nodes.ArrayCoerceExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.ArrayExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Elements", nil, n.Elements) + a.cursor.Replace(n) case nodes.ArrayRef: a.apply(&n, "Xpr", nil, n.Xpr) @@ -350,6 +387,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Reflowerindexpr", nil, n.Reflowerindexpr) a.apply(&n, "Refexpr", nil, n.Refexpr) a.apply(&n, "Refassgnexpr", nil, n.Refassgnexpr) + a.cursor.Replace(n) case nodes.BitString: // pass @@ -360,24 +398,29 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.BoolExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) + a.cursor.Replace(n) case nodes.BooleanTest: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.CaseExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) a.apply(&n, "Args", nil, n.Args) a.apply(&n, "Defresult", nil, n.Defresult) + a.cursor.Replace(n) case nodes.CaseTestExpr: a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.Replace(n) case nodes.CaseWhen: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Expr", nil, n.Expr) a.apply(&n, "Result", nil, n.Result) + a.cursor.Replace(n) case nodes.CheckPointStmt: // pass @@ -388,30 +431,36 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.ClusterStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.Replace(n) } case nodes.CoalesceExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) + a.cursor.Replace(n) case nodes.CoerceToDomain: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.CoerceToDomainValue: a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.Replace(n) case nodes.CoerceViaIO: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.CollateClause: a.apply(&n, "Arg", nil, n.Arg) a.apply(&n, "Collname", nil, n.Collname) - + a.cursor.Replace(n) case nodes.CollateExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.ColumnDef: if n.TypeName != nil { @@ -421,27 +470,33 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "CookedDefault", nil, n.CookedDefault) a.apply(&n, "Constraints", nil, n.Constraints) a.apply(&n, "Fdwoptions", nil, n.Fdwoptions) + a.cursor.Replace(n) case nodes.ColumnRef: a.apply(&n, "Fields", nil, n.Fields) + a.cursor.Replace(n) case nodes.CommentStmt: a.apply(&n, "Object", nil, n.Object) + a.cursor.Replace(n) case nodes.CommonTableExpr: a.apply(&n, "Aliascolnames", nil, n.Aliascolnames) a.apply(&n, "Ctequery", nil, n.Ctequery) a.apply(&n, "Ctecolnames", nil, n.Ctecolnames) a.apply(&n, "Ctecolcollations", nil, n.Ctecolcollations) + a.cursor.Replace(n) case nodes.CompositeTypeStmt: if n.Typevar != nil { a.apply(&n, "Typevar", nil, *n.Typevar) } a.apply(&n, "Coldeflist", nil, n.Coldeflist) + a.cursor.Replace(n) case nodes.Const: a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.Replace(n) case nodes.Constraint: a.apply(&n, "RawExpr", nil, n.RawExpr) @@ -455,13 +510,16 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "FkAttrs", nil, n.FkAttrs) a.apply(&n, "PkAttrs", nil, n.PkAttrs) a.apply(&n, "OldConpfeqop", nil, n.OldConpfeqop) + a.cursor.Replace(n) case nodes.ConstraintsSetStmt: a.apply(&n, "Constraints", nil, n.Constraints) + a.cursor.Replace(n) case nodes.ConvertRowtypeExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.CopyStmt: if n.Relation != nil { @@ -470,9 +528,11 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Query", nil, n.Query) a.apply(&n, "Attlist", nil, n.Attlist) a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.CreateAmStmt: a.apply(&n, "HandlerName", nil, n.HandlerName) + a.cursor.Replace(n) case nodes.CreateCastStmt: if n.Sourcetype != nil { @@ -482,10 +542,12 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Targettype", nil, *n.Targettype) } a.apply(&n, "Func", nil, n.Func) + a.cursor.Replace(n) case nodes.CreateConversionStmt: a.apply(&n, "ConversionName", nil, n.ConversionName) a.apply(&n, "Funcname", nil, n.FuncName) + a.cursor.Replace(n) case nodes.CreateDomainStmt: a.apply(&n, "Domainname", nil, n.Domainname) @@ -496,10 +558,12 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "CollClause", nil, *n.CollClause) } a.apply(&n, "Constraints", nil, n.Constraints) + a.cursor.Replace(n) case nodes.CreateEnumStmt: - a.apply(&n, "", nil, n.TypeName) - a.apply(&n, "", nil, n.Vals) + a.apply(&n, "TypeName", nil, n.TypeName) + a.apply(&n, "Vals", nil, n.Vals) + a.cursor.Replace(n) case nodes.CreateEventTrigStmt: a.apply(&n, "", nil, n.Whenclause) @@ -777,76 +841,85 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "", nil, n.Args) a.apply(&n, "", nil, n.Refs) a.apply(&n, "", nil, n.Cols) + a.cursor.Replace(n) case nodes.GroupingSet: - a.apply(&n, "", nil, n.Content) + a.apply(&n, "Content", nil, n.Content) + a.cursor.Replace(n) case nodes.ImportForeignSchemaStmt: - a.apply(&n, "", nil, n.TableList) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "TableList", nil, n.TableList) + a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.IndexElem: - a.apply(&n, "", nil, n.Expr) - a.apply(&n, "", nil, n.Collation) - a.apply(&n, "", nil, n.Opclass) + a.apply(&n, "Expr", nil, n.Expr) + a.apply(&n, "Collation", nil, n.Collation) + a.apply(&n, "Opclass", nil, n.Opclass) + a.cursor.Replace(n) case nodes.IndexStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.IndexParams) - a.apply(&n, "", nil, n.Options) - a.apply(&n, "", nil, n.WhereClause) - a.apply(&n, "", nil, n.ExcludeOpNames) + a.apply(&n, "IndexParams", nil, n.IndexParams) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "ExcludeOpNames", nil, n.ExcludeOpNames) + a.cursor.Replace(n) case nodes.InferClause: - a.apply(&n, "", nil, n.IndexElems) - a.apply(&n, "", nil, n.WhereClause) + a.apply(&n, "IndexElems", nil, n.IndexElems) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.cursor.Replace(n) case nodes.InferenceElem: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Expr) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Expr", nil, n.Expr) + a.cursor.Replace(n) case nodes.InlineCodeBlock: // pass case nodes.InsertStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.Cols) - a.apply(&n, "", nil, n.SelectStmt) + a.apply(&n, "Cols", nil, n.Cols) + a.apply(&n, "SelectStmt", nil, n.SelectStmt) if n.OnConflictClause != nil { - a.apply(&n, "", nil, *n.OnConflictClause) + a.apply(&n, "OnConflictClause", nil, *n.OnConflictClause) } - a.apply(&n, "", nil, n.ReturningList) + a.apply(&n, "ReturningList", nil, n.ReturningList) if n.WithClause != nil { - a.apply(&n, "", nil, *n.WithClause) + a.apply(&n, "WithClause", nil, *n.WithClause) } + a.cursor.Replace(n) case nodes.Integer: // pass case nodes.IntoClause: if n.Rel != nil { - a.apply(&n, "", nil, *n.Rel) + a.apply(&n, "Rel", nil, *n.Rel) } - a.apply(&n, "", nil, n.ColNames) - a.apply(&n, "", nil, n.Options) - a.apply(&n, "", nil, n.ViewQuery) + a.apply(&n, "ColNames", nil, n.ColNames) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "ViewQuery", nil, n.ViewQuery) + a.cursor.Replace(n) case nodes.JoinExpr: - a.apply(&n, "", nil, n.Larg) - a.apply(&n, "", nil, n.Rarg) - a.apply(&n, "", nil, n.UsingClause) - a.apply(&n, "", nil, n.Quals) + a.apply(&n, "Larg", nil, n.Larg) + a.apply(&n, "Rarg", nil, n.Rarg) + a.apply(&n, "UsingClause", nil, n.UsingClause) + a.apply(&n, "Quals", nil, n.Quals) if n.Alias != nil { - a.apply(&n, "", nil, *n.Alias) + a.apply(&n, "Alias", nil, *n.Alias) } + a.cursor.Replace(n) case nodes.List: a.applyList(&n, "Items") - spew.Dump(a.cursor) a.cursor.Replace(n) case nodes.ListenStmt: @@ -856,24 +929,24 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node // pass case nodes.LockStmt: - a.apply(&n, "", nil, n.Relations) + a.apply(&n, "Relations", nil, n.Relations) case nodes.LockingClause: - a.apply(&n, "", nil, n.LockedRels) + a.apply(&n, "LockedRels", nil, n.LockedRels) case nodes.MinMaxExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) case nodes.MultiAssignRef: - a.apply(&n, "", nil, n.Source) + a.apply(&n, "Source", nil, n.Source) case nodes.NamedArgExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Arg) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Arg) case nodes.NextValueExpr: - a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "Xpr", nil, n.Xpr) case nodes.NotifyStmt: // pass @@ -882,19 +955,19 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node // pass case nodes.NullTest: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Arg) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) case nodes.ObjectWithArgs: - a.apply(&n, "", nil, n.Objname) - a.apply(&n, "", nil, n.Objargs) + a.apply(&n, "Objname", nil, n.Objname) + a.apply(&n, "Objargs", nil, n.Objargs) case nodes.OnConflictClause: if n.Infer != nil { - a.apply(&n, "", nil, *n.Infer) + a.apply(&n, "Infer", nil, *n.Infer) } - a.apply(&n, "", nil, n.TargetList) - a.apply(&n, "", nil, n.WhereClause) + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "WhereClause", nil, n.WhereClause) case nodes.OnConflictExpr: a.apply(&n, "", nil, n.ArbiterElems) @@ -1189,109 +1262,132 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.TableLikeClause: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.Replace(n) } case nodes.TableSampleClause: - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.Repeatable) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Repeatable", nil, n.Repeatable) + a.cursor.Replace(n) case nodes.TargetEntry: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Expr) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Expr", nil, n.Expr) + a.cursor.Replace(n) case nodes.TransactionStmt: - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.TriggerTransition: // pass case nodes.TruncateStmt: - a.apply(&n, "", nil, n.Relations) + a.apply(&n, "Relations", nil, n.Relations) + a.cursor.Replace(n) case nodes.TypeCast: - a.apply(&n, "", nil, n.Arg) - if n.TypeName != nil { - a.apply(&n, "", nil, *n.TypeName) - } + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "TypeName", nil, n.TypeName) + a.cursor.Replace(n) case nodes.TypeName: - a.apply(&n, "", nil, n.Names) - a.apply(&n, "", nil, n.Typmods) - a.apply(&n, "", nil, n.ArrayBounds) + a.apply(&n, "Names", nil, n.Names) + a.apply(&n, "Typmods", nil, n.Typmods) + a.apply(&n, "ArrayBounds", nil, n.ArrayBounds) + a.cursor.Replace(n) + + case *nodes.TypeName: + a.apply(n, "Names", nil, n.Names) + a.apply(n, "Typmods", nil, n.Typmods) + a.apply(n, "ArrayBounds", nil, n.ArrayBounds) + a.cursor.Replace(n) case nodes.UnlistenStmt: // pass case nodes.UpdateStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.TargetList) - a.apply(&n, "", nil, n.WhereClause) - a.apply(&n, "", nil, n.FromClause) - a.apply(&n, "", nil, n.ReturningList) + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "FromClause", nil, n.FromClause) + a.apply(&n, "ReturningList", nil, n.ReturningList) if n.WithClause != nil { - a.apply(&n, "", nil, *n.WithClause) + a.apply(&n, "WithClause", nil, *n.WithClause) } + a.cursor.Replace(n) case nodes.VacuumStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.VaCols) + a.apply(&n, "VaCols", nil, n.VaCols) + a.cursor.Replace(n) case nodes.Var: - a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.Replace(n) case nodes.VariableSetStmt: - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Args", nil, n.Args) + a.cursor.Replace(n) case nodes.VariableShowStmt: // pass case nodes.ViewStmt: if n.View != nil { - a.apply(&n, "", nil, *n.View) + a.apply(&n, "View", nil, *n.View) } - a.apply(&n, "", nil, n.Aliases) - a.apply(&n, "", nil, n.Query) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Aliases", nil, n.Aliases) + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Options", nil, n.Options) + a.cursor.Replace(n) case nodes.WindowClause: - a.apply(&n, "", nil, n.PartitionClause) - a.apply(&n, "", nil, n.OrderClause) - a.apply(&n, "", nil, n.StartOffset) - a.apply(&n, "", nil, n.EndOffset) + a.apply(&n, "PartitionClause", nil, n.PartitionClause) + a.apply(&n, "OrderClause", nil, n.OrderClause) + a.apply(&n, "StartOffset", nil, n.StartOffset) + a.apply(&n, "EndOffset", nil, n.EndOffset) + a.cursor.Replace(n) case nodes.WindowDef: - a.apply(&n, "", nil, n.PartitionClause) - a.apply(&n, "", nil, n.OrderClause) - a.apply(&n, "", nil, n.StartOffset) - a.apply(&n, "", nil, n.EndOffset) + a.apply(&n, "PartitionClause", nil, n.PartitionClause) + a.apply(&n, "OrderClause", nil, n.OrderClause) + a.apply(&n, "StartOffset", nil, n.StartOffset) + a.apply(&n, "EndOffset", nil, n.EndOffset) + a.cursor.Replace(n) case nodes.WindowFunc: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.Aggfilter) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Aggfilter", nil, n.Aggfilter) + a.cursor.Replace(n) case nodes.WithCheckOption: - a.apply(&n, "", nil, n.Qual) + a.apply(&n, "Qual", nil, n.Qual) + a.cursor.Replace(n) case nodes.WithClause: - a.apply(&n, "", nil, n.Ctes) + a.apply(&n, "Ctes", nil, n.Ctes) + a.cursor.Replace(n) case nodes.XmlExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.NamedArgs) - a.apply(&n, "", nil, n.ArgNames) - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "NamedArgs", nil, n.NamedArgs) + a.apply(&n, "ArgNames", nil, n.ArgNames) + a.apply(&n, "Args", nil, n.Args) + a.cursor.Replace(n) case nodes.XmlSerialize: - a.apply(&n, "", nil, n.Expr) + a.apply(&n, "Expr", nil, n.Expr) if n.TypeName != nil { - a.apply(&n, "", nil, *n.TypeName) + a.apply(&n, "TypeName", nil, *n.TypeName) } + a.cursor.Replace(n) default: panic(fmt.Sprintf("Apply: unexpected node type %T", n)) From e76ff5847eac2e3fb202e4662883f1af16f24e23 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Thu, 16 Jan 2020 21:52:41 -0800 Subject: [PATCH 05/11] Only rewrite queries with named parameters --- internal/dinosql/rewrite.go | 8 ++++ internal/postgresql/ast/astutil.go | 68 +++++++++++++++++++----------- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go index 4c88bfa426..91621cf2f0 100644 --- a/internal/dinosql/rewrite.go +++ b/internal/dinosql/rewrite.go @@ -27,6 +27,14 @@ func (s *stringWalker) Visit(node nodes.Node) ast.Visitor { } func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) { + found := search(raw, func(node nodes.Node) bool { + fun, ok := node.(nodes.FuncCall) + return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg" + }) + if len(found.Items) == 0 { + return raw, map[int]string{}, nil + } + args := map[string]int{} argn := 0 var edits []edit diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go index 9e19d4226a..d4431abf00 100644 --- a/internal/postgresql/ast/astutil.go +++ b/internal/postgresql/ast/astutil.go @@ -1103,7 +1103,8 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.RangeVar: if n.Alias != nil { - a.apply(&n, "", nil, *n.Alias) + a.apply(&n, "Alias", nil, *n.Alias) + a.cursor.Replace(n) } case nodes.RawStmt: @@ -1111,30 +1112,35 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.cursor.Replace(n) case nodes.ReassignOwnedStmt: - a.apply(&n, "", nil, n.Roles) + a.apply(&n, "Roles", nil, n.Roles) if n.Newrole != nil { - a.apply(&n, "", nil, *n.Newrole) + a.apply(&n, "Newrole", nil, *n.Newrole) } + a.cursor.Replace(n) case nodes.RefreshMatViewStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.Replace(n) } case nodes.ReindexStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) + a.cursor.Replace(n) } case nodes.RelabelType: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Arg) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.Replace(n) case nodes.RenameStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.Object) + a.apply(&n, "Object", nil, n.Object) + a.cursor.Replace(n) case nodes.ReplicaIdentityStmt: // pass @@ -1148,37 +1154,43 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node // pass case nodes.RowCompareExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Opnos) - a.apply(&n, "", nil, n.Opfamilies) - a.apply(&n, "", nil, n.Inputcollids) - a.apply(&n, "", nil, n.Largs) - a.apply(&n, "", nil, n.Rargs) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Opnos", nil, n.Opnos) + a.apply(&n, "Opfamilies", nil, n.Opfamilies) + a.apply(&n, "Inputcollids", nil, n.Inputcollids) + a.apply(&n, "Largs", nil, n.Largs) + a.apply(&n, "Rargs", nil, n.Rargs) + a.cursor.Replace(n) case nodes.RowExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.Colnames) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Colnames", nil, n.Colnames) + a.cursor.Replace(n) case nodes.RowMarkClause: // pass case nodes.RuleStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.WhereClause) - a.apply(&n, "", nil, n.Actions) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "Actions", nil, n.Actions) + a.cursor.Replace(n) case nodes.SQLValueFunction: - a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.Replace(n) case nodes.ScalarArrayOpExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.Replace(n) case nodes.SecLabelStmt: - a.apply(&n, "", nil, n.Object) + a.apply(&n, "Object", nil, n.Object) + a.cursor.Replace(n) case nodes.SelectStmt: a.apply(&n, "DistinctClause", nil, n.DistinctClause) @@ -1220,13 +1232,16 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "ColTypmods", nil, n.ColTypmods) a.apply(&n, "ColCollations", nil, n.ColCollations) a.apply(&n, "GroupClauses", nil, n.GroupClauses) + a.cursor.Replace(n) case nodes.SetToDefault: a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.Replace(n) case nodes.SortBy: a.apply(&n, "Node", nil, n.Node) a.apply(&n, "UseOp", nil, n.UseOp) + a.cursor.Replace(n) case nodes.SortGroupClause: // pass @@ -1239,6 +1254,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Testexpr", nil, n.Testexpr) a.apply(&n, "Opername", nil, n.OperName) a.apply(&n, "Subselect", nil, n.Subselect) + a.cursor.Replace(n) case nodes.SubPlan: a.apply(&n, "Xpr", nil, n.Xpr) @@ -1247,6 +1263,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "SetParam", nil, n.SetParam) a.apply(&n, "ParParam", nil, n.ParParam) a.apply(&n, "Args", nil, n.Args) + a.cursor.Replace(n) case nodes.TableFunc: a.apply(&n, "NsUris", nil, n.NsUris) @@ -1259,6 +1276,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Colcollations", nil, n.Colcollations) a.apply(&n, "Colexprs", nil, n.Colexprs) a.apply(&n, "Coldefexprs", nil, n.Coldefexprs) + a.cursor.Replace(n) case nodes.TableLikeClause: if n.Relation != nil { From 57d614e7b98d588caccf26d7175ae6b62b77dbc6 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Thu, 16 Jan 2020 21:58:22 -0800 Subject: [PATCH 06/11] Remove unused packages --- internal/dinosql/parser.go | 5 - internal/mysql/example/db.go | 29 ------ internal/mysql/example/models.go | 37 ------- internal/mysql/example/queries.sql.go | 100 ------------------- internal/named/named.go | 133 -------------------------- internal/named/named_test.go | 121 ----------------------- 6 files changed, 425 deletions(-) delete mode 100644 internal/mysql/example/db.go delete mode 100644 internal/mysql/example/models.go delete mode 100644 internal/mysql/example/queries.sql.go delete mode 100644 internal/named/named.go delete mode 100644 internal/named/named_test.go diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index cc57e997a9..b3baf51ae3 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -232,11 +232,6 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) { continue } source := string(blob) - // source, _, err := named.CompileNamedQuery(blob, named.DOLLAR) - // if err != nil { - // merr.Add(filename, "", 0, err) - // continue - // } tree, err := pg.Parse(source) if err != nil { merr.Add(filename, source, 0, err) diff --git a/internal/mysql/example/db.go b/internal/mysql/example/db.go deleted file mode 100644 index 3df4e8545c..0000000000 --- a/internal/mysql/example/db.go +++ /dev/null @@ -1,29 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. - -package teachersDB - -import ( - "context" - "database/sql" -) - -type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row -} - -func New(db DBTX) *Queries { - return &Queries{db: db} -} - -type Queries struct { - db DBTX -} - -func (q *Queries) WithTx(tx *sql.Tx) *Queries { - return &Queries{ - db: tx, - } -} diff --git a/internal/mysql/example/models.go b/internal/mysql/example/models.go deleted file mode 100644 index 7655ff580f..0000000000 --- a/internal/mysql/example/models.go +++ /dev/null @@ -1,37 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. - -package teachersDB - -import ( - "database/sql" -) - -type DepartmentType string - -const ( - English DepartmentType = "English" - Math DepartmentType = "Math" -) - -func (e *DepartmentType) Scan(src interface{}) error { - *e = DepartmentType(src.([]byte)) - return nil -} - -type Teacher struct { - ID int `json:"id"` - FirstName sql.NullString `json:"first_name"` - LastName sql.NullString `json:"last_name"` - SchoolID int `json:"school_id"` - ClassID int `json:"class_id"` - SchoolLat sql.NullFloat64 `json:"school_lat"` - SchoolLng sql.NullFloat64 `json:"school_lng"` - Department DepartmentType `json:"department"` -} - -type Student struct { - ID int `json:"id"` - ClassID int `json:"class_id"` - FirstName sql.NullString `json:"first_name"` - LastName sql.NullString `json:"last_name"` -} diff --git a/internal/mysql/example/queries.sql.go b/internal/mysql/example/queries.sql.go deleted file mode 100644 index 1d1a3570a6..0000000000 --- a/internal/mysql/example/queries.sql.go +++ /dev/null @@ -1,100 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// source: queries.sql - -package teachersDB - -import ( - "context" - "database/sql" -) - -const getSomeTeachers = `-- name: GetSomeTeachers :one -select school_id, id from teachers where school_lng > ? and school_lat < ? -` - -type GetSomeTeachersParams struct { - SchoolLng sql.NullFloat64 `json:"school_lng"` - SchoolLat sql.NullFloat64 `json:"school_lat"` -} - -type GetSomeTeachersRow struct { - SchoolID int `json:"school_id"` - ID int `json:"id"` -} - -func (q *Queries) GetSomeTeachers(ctx context.Context, arg GetSomeTeachersParams) (GetSomeTeachersRow, error) { - row := q.db.QueryRowContext(ctx, getSomeTeachers, arg.SchoolLng, arg.SchoolLat) - var i GetSomeTeachersRow - err := row.Scan(&i.SchoolID, &i.ID) - return i, err -} - -const getStudentsTeacher = `-- name: GetStudentsTeacher :one -select students.first_name, students.last_name, teachers.first_name as teacherFirstName, teachers.id as teacher_id from students left join teachers on teachers.class_id = students.class_id where students.id = ? -` - -type GetStudentsTeacherRow struct { - FirstName sql.NullString `json:"first_name"` - LastName sql.NullString `json:"last_name"` - TeacherFirstName sql.NullString `json:"teacherFirstName"` - TeacherID sql.NullInt64 `json:"teacher_id"` -} - -func (q *Queries) GetStudentsTeacher(ctx context.Context, studentID int) (GetStudentsTeacherRow, error) { - row := q.db.QueryRowContext(ctx, getStudentsTeacher, studentID) - var i GetStudentsTeacherRow - err := row.Scan( - &i.FirstName, - &i.LastName, - &i.TeacherFirstName, - &i.TeacherID, - ) - return i, err -} - -const getTeachersByID = `-- name: GetTeachersByID :one -select id, first_name, last_name, school_id, class_id, school_lat, school_lng, department from teachers where id = ? -` - -type GetTeachersByIDRow struct { - ID int `json:"id"` - FirstName sql.NullString `json:"first_name"` - LastName sql.NullString `json:"last_name"` - SchoolID int `json:"school_id"` - ClassID int `json:"class_id"` - SchoolLat sql.NullFloat64 `json:"school_lat"` - SchoolLng sql.NullFloat64 `json:"school_lng"` - Department DepartmentType `json:"department"` -} - -func (q *Queries) GetTeachersByID(ctx context.Context, id int) (GetTeachersByIDRow, error) { - row := q.db.QueryRowContext(ctx, getTeachersByID, id) - var i GetTeachersByIDRow - err := row.Scan( - &i.ID, - &i.FirstName, - &i.LastName, - &i.SchoolID, - &i.ClassID, - &i.SchoolLat, - &i.SchoolLng, - &i.Department, - ) - return i, err -} - -const teachersByID = `-- name: TeachersByID :one -select id, school_lat from teachers where id = ? limit 10 -` - -type TeachersByIDRow struct { - ID int `json:"id"` - SchoolLat sql.NullFloat64 `json:"school_lat"` -} - -func (q *Queries) TeachersByID(ctx context.Context, id int) (TeachersByIDRow, error) { - row := q.db.QueryRowContext(ctx, teachersByID, id) - var i TeachersByIDRow - err := row.Scan(&i.ID, &i.SchoolLat) - return i, err -} diff --git a/internal/named/named.go b/internal/named/named.go deleted file mode 100644 index 2cb7253220..0000000000 --- a/internal/named/named.go +++ /dev/null @@ -1,133 +0,0 @@ -package named - -// Copyright (c) 2013, Jason Moiron -// -// Permission is hereby granted, free of charge, to any person -// obtaining a copy of this software and associated documentation -// files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, -// copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -// OTHER DEALINGS IN THE SOFTWARE. - -import ( - "errors" - "strconv" - "unicode" -) - -// Bindvar types supported by Rebind, BindMap and BindStruct. -const ( - UNKNOWN = iota - QUESTION - DOLLAR - NAMED - AT -) - -// -- Compilation of Named Queries - -// Allow digits and letters in bind params; additionally runes are -// checked against underscores, meaning that bind params can have be -// alphanumeric with underscores. Mind the difference between unicode -// digits and numbers, where '5' is a digit but '五' is not. -var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} - -// FIXME: this function isn't safe for unicode named params, as a failing test -// can testify. This is not a regression but a failure of the original code -// as well. It should be modified to range over runes in a string rather than -// bytes, even though this is less convenient and slower. Hopefully the -// addition of the prepared NamedStmt (which will only do this once) will make -// up for the slightly slower ad-hoc NamedExec/NamedQuery. - -// compile a NamedQuery into an unbound query (using the '?' bindvar) and -// a list of names. -func CompileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { - names = make([]string, 0, 10) - rebound := make([]byte, 0, len(qs)) - - inName := false - last := len(qs) - 1 - currentVar := 1 - name := make([]byte, 0, 10) - - for i, b := range qs { - // a ':' while we're in a name is an error - if b == ':' { - // if this is the second ':' in a '::' escape sequence, append a ':' - if inName && i > 0 && qs[i-1] == ':' { - rebound = append(rebound, ':') - inName = false - continue - } else if inName { - err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) - return query, names, err - } - inName = true - name = []byte{} - } else if inName && i > 0 && b == '=' && len(name) == 0 { - rebound = append(rebound, ':', '=') - inName = false - continue - // if we're in a name, and this is an allowed character, continue - } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { - // append the byte to the name if we are in a name and not on the last byte - name = append(name, b) - // if we're in a name and it's not an allowed character, the name is done - } else if inName { - inName = false - // if this is the final byte of the string and it is part of the name, then - // make sure to add it to the name - if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { - name = append(name, b) - } - // add the string representation to the names list - names = append(names, string(name)) - // add a proper bindvar for the bindType - switch bindType { - // oracle only supports named type bind vars even for positional - case NAMED: - rebound = append(rebound, ':') - rebound = append(rebound, name...) - case QUESTION, UNKNOWN: - rebound = append(rebound, '?') - case DOLLAR: - rebound = append(rebound, '$') - for _, b := range strconv.Itoa(currentVar) { - rebound = append(rebound, byte(b)) - } - currentVar++ - case AT: - rebound = append(rebound, '@', 'p') - for _, b := range strconv.Itoa(currentVar) { - rebound = append(rebound, byte(b)) - } - currentVar++ - } - // add this byte to string unless it was not part of the name - if i != last { - rebound = append(rebound, b) - } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { - rebound = append(rebound, b) - } - } else { - // this is a normal byte and should just go onto the rebound query - rebound = append(rebound, b) - } - } - - return string(rebound), names, err -} diff --git a/internal/named/named_test.go b/internal/named/named_test.go deleted file mode 100644 index 9c0aaff715..0000000000 --- a/internal/named/named_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package named - -// Copyright (c) 2013, Jason Moiron -// -// Permission is hereby granted, free of charge, to any person -// obtaining a copy of this software and associated documentation -// files (the "Software"), to deal in the Software without -// restriction, including without limitation the rights to use, -// copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -// OTHER DEALINGS IN THE SOFTWARE. - -import ( - "testing" -) - -func TestCompileQuery(t *testing.T) { - table := []struct { - Q, R, D, T, N string - V []string - }{ - // basic test for named parameters, invalid char ',' terminating - { - Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, - R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, - D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, - T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`, - N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, - V: []string{"name", "age", "first", "last"}, - }, - // This query tests a named parameter ending the string as well as numbers - { - Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, - R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, - D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, - T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`, - N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, - V: []string{"name1", "name2"}, - }, - { - Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, - R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, - D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, - T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`, - N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, - V: []string{"name1", "name2"}, - }, - { - Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, - R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, - D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, - T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`, - N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, - V: []string{"first_name", "last_name"}, - }, - { - Q: `SELECT @name := "name", :age, :first, :last`, - R: `SELECT @name := "name", ?, ?, ?`, - D: `SELECT @name := "name", $1, $2, $3`, - N: `SELECT @name := "name", :age, :first, :last`, - T: `SELECT @name := "name", @p1, @p2, @p3`, - V: []string{"age", "first", "last"}, - }, - /* This unicode awareness test sadly fails, because of our byte-wise worldview. - * We could certainly iterate by Rune instead, though it's a great deal slower, - * it's probably the RightWay(tm) - { - Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, - R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, - D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, - N: []string{"name", "age", "first", "last"}, - }, - */ - } - - for _, test := range table { - qr, names, err := CompileNamedQuery([]byte(test.Q), QUESTION) - if err != nil { - t.Error(err) - } - if qr != test.R { - t.Errorf("expected %s, got %s", test.R, qr) - } - if len(names) != len(test.V) { - t.Errorf("expected %#v, got %#v", test.V, names) - } else { - for i, name := range names { - if name != test.V[i] { - t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) - } - } - } - qd, _, _ := CompileNamedQuery([]byte(test.Q), DOLLAR) - if qd != test.D { - t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) - } - - qt, _, _ := CompileNamedQuery([]byte(test.Q), AT) - if qt != test.T { - t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) - } - - qq, _, _ := CompileNamedQuery([]byte(test.Q), NAMED) - if qq != test.N { - t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) - } - } -} From c4dca5894ef412e7ea245d01beef567d84c1da36 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Thu, 16 Jan 2020 22:04:16 -0800 Subject: [PATCH 07/11] Simple upkeep --- internal/dinosql/checks.go | 5 +---- internal/dinosql/rewrite.go | 25 ++++++++++++++----------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/dinosql/checks.go b/internal/dinosql/checks.go index 15f81c4dbd..7d654ecef3 100644 --- a/internal/dinosql/checks.go +++ b/internal/dinosql/checks.go @@ -130,10 +130,7 @@ func validateParamStyle(n nodes.Node) error { _, ok := node.(nodes.ParamRef) return ok }) - named := search(n, func(node nodes.Node) bool { - fun, ok := node.(nodes.FuncCall) - return ok && join(fun.Funcname, ".") == "sqlc.arg" - }) + named := search(n, isNamedParamFunc) if len(named.Items) > 0 && len(positional.Items) > 0 { return pg.Error{ Code: "", // TODO: Pick a new error code diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go index 91621cf2f0..fa363ec952 100644 --- a/internal/dinosql/rewrite.go +++ b/internal/dinosql/rewrite.go @@ -26,11 +26,18 @@ func (s *stringWalker) Visit(node nodes.Node) ast.Visitor { return s } +func isNamedParamFunc(node nodes.Node) bool { + fun, ok := node.(nodes.FuncCall) + return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg" +} + +func isNamedParamSign(node nodes.Node) bool { + fun, ok := node.(nodes.FuncCall) + return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg" +} + func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) { - found := search(raw, func(node nodes.Node) bool { - fun, ok := node.(nodes.FuncCall) - return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg" - }) + found := search(raw, isNamedParamFunc) if len(found.Items) == 0 { return raw, map[int]string{}, nil } @@ -39,11 +46,9 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ argn := 0 var edits []edit node := ast.Apply(raw, func(cr *ast.Cursor) bool { - fun, ok := cr.Node().(nodes.FuncCall) - if !ok { - return true - } - if ast.Join(fun.Funcname, ".") == "sqlc.arg" { + node := cr.Node() + if isNamedParamFunc(node) { + fun := node.(nodes.FuncCall) param := flatten(fun.Args) if num, ok := args[param]; ok { cr.Replace(nodes.ParamRef{ @@ -58,14 +63,12 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ Location: fun.Location, }) } - // TODO: This code assumes that sqlc.arg(name) is on a single line edits = append(edits, edit{ Location: fun.Location - raw.StmtLocation, Old: fmt.Sprintf("sqlc.arg(%s)", param), New: fmt.Sprintf("$%d", args[param]), }) - return false } return true From 222ce91d1b95ae10fda13a539716b11720006e38 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 17 Jan 2020 13:00:25 -0800 Subject: [PATCH 08/11] Support pointers --- internal/postgresql/ast/astutil.go | 268 +++++++++++++++-------------- 1 file changed, 141 insertions(+), 127 deletions(-) diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go index d4431abf00..645673353e 100644 --- a/internal/postgresql/ast/astutil.go +++ b/internal/postgresql/ast/astutil.go @@ -112,6 +112,20 @@ func (c *Cursor) Replace(n nodes.Node) { v.Set(reflect.ValueOf(n)) } +// Replace replaces the current Node with n. +// The replacement node is not walked by Apply. +func (c *Cursor) set(val nodes.Node, ptr nodes.Node) { + v := c.field() + if i := c.Index(); i >= 0 { + v = v.Index(i) + } + if v.Type().Kind() == reflect.Ptr { + v.Set(reflect.ValueOf(ptr)) + } else { + v.Set(reflect.ValueOf(val)) + } +} + // application carries all the shared data so we can pass it around cheaply. type application struct { pre, post ApplyFunc @@ -125,6 +139,9 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node node = nil } + // TODO: If node is a pointer, dereference it. This prevents us from having + // to have nil checks in the case statement + // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead saved := a.cursor a.cursor.parent = parent @@ -145,34 +162,34 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.A_ArrayExpr: a.apply(&n, "Elements", nil, n.Elements) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.A_Const: a.apply(&n, "Val", nil, n.Val) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.A_Expr: a.apply(&n, "Name", nil, n.Name) a.apply(&n, "Lexpr", nil, n.Lexpr) a.apply(&n, "Rexpr", nil, n.Rexpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.A_Indices: a.apply(&n, "Lidx", nil, n.Lidx) a.apply(&n, "Uidx", nil, n.Uidx) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.A_Indirection: a.apply(&n, "Arg", nil, n.Arg) a.apply(&n, "Indirection", nil, n.Indirection) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.A_Star: // pass case nodes.AccessPriv: a.apply(&n, "Cols", nil, n.Cols) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.Aggref: a.apply(&n, "Xpr", nil, n.Xpr) @@ -182,68 +199,69 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Aggorder", nil, n.Aggorder) a.apply(&n, "Aggdistinct", nil, n.Aggdistinct) a.apply(&n, "Aggfilter", nil, n.Aggfilter) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.Alias: a.apply(&n, "Colnames", nil, n.Colnames) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterCollationStmt: a.apply(&n, "Collname", nil, n.Collname) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterDatabaseSetStmt: if n.Setstmt != nil { a.apply(&n, "Setstmt", nil, *n.Setstmt) - a.cursor.Replace(n) + a.cursor.set(n, &n) } case nodes.AlterDatabaseStmt: a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterDefaultPrivilegesStmt: if n.Action != nil { a.apply(&n, "Action", nil, *n.Action) } a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + // TODOO: Take a pointer or not: a.cursor.set(n, &n, &n) + a.cursor.set(n, &n) case nodes.AlterDomainStmt: a.apply(&n, "TypeName", nil, n.TypeName) a.apply(&n, "Def", nil, n.Def) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterEnumStmt: a.apply(&n, "TypeName", nil, n.TypeName) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterEventTrigStmt: // pass case nodes.AlterExtensionContentsStmt: a.apply(&n, "Object", nil, n.Object) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterExtensionStmt: a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterFdwStmt: a.apply(&n, "FuncOptions", nil, n.FuncOptions) a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterForeignServerStmt: a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterFunctionStmt: if n.Func != nil { a.apply(&n, "Func", nil, n.Func) } a.apply(&n, "Actions", nil, n.Actions) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterObjectDependsStmt: if n.Relation != nil { @@ -251,26 +269,26 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node } a.apply(&n, "Object", nil, n.Object) a.apply(&n, "Extname", nil, n.Extname) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterObjectSchemaStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) } a.apply(&n, "Object", nil, n.Object) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterOpFamilyStmt: a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) a.apply(&n, "Items", nil, n.Items) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterOperatorStmt: if n.Opername != nil { a.apply(&n, "Opername", nil, *n.Opername) } a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterOwnerStmt: if n.Relation != nil { @@ -280,7 +298,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node if n.Newowner != nil { a.apply(&n, "Newowner", nil, *n.Newowner) } - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterPolicyStmt: if n.Table != nil { @@ -289,97 +307,97 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Roles", nil, n.Roles) a.apply(&n, "Qual", nil, n.Qual) a.apply(&n, "WithCheck", nil, n.WithCheck) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterPublicationStmt: a.apply(&n, "Options", nil, n.Options) a.apply(&n, "Table", nil, n.Tables) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterRoleSetStmt: if n.Role != nil { a.apply(&n, "Role", nil, *n.Role) } a.apply(&n, "Setstmt", nil, n.Setstmt) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterRoleStmt: if n.Role != nil { a.apply(&n, "Role", nil, *n.Role) } a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterSeqStmt: if n.Sequence != nil { a.apply(&n, "Sequence", nil, *n.Sequence) } a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterSubscriptionStmt: a.apply(&n, "Publication", nil, n.Publication) a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterSystemStmt: a.apply(&n, "Setstmt", nil, n.Setstmt) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterTSConfigurationStmt: a.apply(&n, "Cfgname", nil, n.Cfgname) a.apply(&n, "Tokentype", nil, n.Tokentype) a.apply(&n, "Dicts", nil, n.Dicts) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterTSDictionaryStmt: a.apply(&n, "Dictname", nil, n.Dictname) a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterTableCmd: if n.Newowner != nil { a.apply(&n, "Newowner", nil, *n.Newowner) } a.apply(&n, "Def", nil, n.Def) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterTableMoveAllStmt: a.apply(&n, "Roles", nil, n.Roles) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterTableSpaceOptionsStmt: a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterTableStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) } a.apply(&n, "Cmds", nil, n.Cmds) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlterUserMappingStmt: if n.User != nil { a.apply(&n, "User", nil, *n.User) } a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.AlternativeSubPlan: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Subplans", nil, n.Subplans) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ArrayCoerceExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ArrayExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Elements", nil, n.Elements) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ArrayRef: a.apply(&n, "Xpr", nil, n.Xpr) @@ -387,7 +405,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Reflowerindexpr", nil, n.Reflowerindexpr) a.apply(&n, "Refexpr", nil, n.Refexpr) a.apply(&n, "Refassgnexpr", nil, n.Refassgnexpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.BitString: // pass @@ -398,29 +416,29 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.BoolExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.BooleanTest: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CaseExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) a.apply(&n, "Args", nil, n.Args) a.apply(&n, "Defresult", nil, n.Defresult) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CaseTestExpr: a.apply(&n, "Xpr", nil, n.Xpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CaseWhen: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Expr", nil, n.Expr) a.apply(&n, "Result", nil, n.Result) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CheckPointStmt: // pass @@ -431,36 +449,36 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.ClusterStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) - a.cursor.Replace(n) + a.cursor.set(n, &n) } case nodes.CoalesceExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CoerceToDomain: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CoerceToDomainValue: a.apply(&n, "Xpr", nil, n.Xpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CoerceViaIO: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CollateClause: a.apply(&n, "Arg", nil, n.Arg) a.apply(&n, "Collname", nil, n.Collname) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CollateExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ColumnDef: if n.TypeName != nil { @@ -470,33 +488,33 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "CookedDefault", nil, n.CookedDefault) a.apply(&n, "Constraints", nil, n.Constraints) a.apply(&n, "Fdwoptions", nil, n.Fdwoptions) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ColumnRef: a.apply(&n, "Fields", nil, n.Fields) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CommentStmt: a.apply(&n, "Object", nil, n.Object) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CommonTableExpr: a.apply(&n, "Aliascolnames", nil, n.Aliascolnames) a.apply(&n, "Ctequery", nil, n.Ctequery) a.apply(&n, "Ctecolnames", nil, n.Ctecolnames) a.apply(&n, "Ctecolcollations", nil, n.Ctecolcollations) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CompositeTypeStmt: if n.Typevar != nil { a.apply(&n, "Typevar", nil, *n.Typevar) } a.apply(&n, "Coldeflist", nil, n.Coldeflist) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.Const: a.apply(&n, "Xpr", nil, n.Xpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.Constraint: a.apply(&n, "RawExpr", nil, n.RawExpr) @@ -510,16 +528,16 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "FkAttrs", nil, n.FkAttrs) a.apply(&n, "PkAttrs", nil, n.PkAttrs) a.apply(&n, "OldConpfeqop", nil, n.OldConpfeqop) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ConstraintsSetStmt: a.apply(&n, "Constraints", nil, n.Constraints) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ConvertRowtypeExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CopyStmt: if n.Relation != nil { @@ -528,11 +546,11 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Query", nil, n.Query) a.apply(&n, "Attlist", nil, n.Attlist) a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CreateAmStmt: a.apply(&n, "HandlerName", nil, n.HandlerName) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CreateCastStmt: if n.Sourcetype != nil { @@ -542,12 +560,12 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Targettype", nil, *n.Targettype) } a.apply(&n, "Func", nil, n.Func) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CreateConversionStmt: a.apply(&n, "ConversionName", nil, n.ConversionName) a.apply(&n, "Funcname", nil, n.FuncName) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CreateDomainStmt: a.apply(&n, "Domainname", nil, n.Domainname) @@ -558,12 +576,12 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "CollClause", nil, *n.CollClause) } a.apply(&n, "Constraints", nil, n.Constraints) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CreateEnumStmt: a.apply(&n, "TypeName", nil, n.TypeName) a.apply(&n, "Vals", nil, n.Vals) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.CreateEventTrigStmt: a.apply(&n, "", nil, n.Whenclause) @@ -841,22 +859,22 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "", nil, n.Args) a.apply(&n, "", nil, n.Refs) a.apply(&n, "", nil, n.Cols) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.GroupingSet: a.apply(&n, "Content", nil, n.Content) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ImportForeignSchemaStmt: a.apply(&n, "TableList", nil, n.TableList) a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.IndexElem: a.apply(&n, "Expr", nil, n.Expr) a.apply(&n, "Collation", nil, n.Collation) a.apply(&n, "Opclass", nil, n.Opclass) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.IndexStmt: if n.Relation != nil { @@ -866,17 +884,17 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Options", nil, n.Options) a.apply(&n, "WhereClause", nil, n.WhereClause) a.apply(&n, "ExcludeOpNames", nil, n.ExcludeOpNames) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.InferClause: a.apply(&n, "IndexElems", nil, n.IndexElems) a.apply(&n, "WhereClause", nil, n.WhereClause) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.InferenceElem: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Expr", nil, n.Expr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.InlineCodeBlock: // pass @@ -894,7 +912,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node if n.WithClause != nil { a.apply(&n, "WithClause", nil, *n.WithClause) } - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.Integer: // pass @@ -906,7 +924,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "ColNames", nil, n.ColNames) a.apply(&n, "Options", nil, n.Options) a.apply(&n, "ViewQuery", nil, n.ViewQuery) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.JoinExpr: a.apply(&n, "Larg", nil, n.Larg) @@ -916,11 +934,11 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node if n.Alias != nil { a.apply(&n, "Alias", nil, *n.Alias) } - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.List: a.applyList(&n, "Items") - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ListenStmt: // pass @@ -1104,43 +1122,43 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.RangeVar: if n.Alias != nil { a.apply(&n, "Alias", nil, *n.Alias) - a.cursor.Replace(n) + a.cursor.set(n, &n) } case nodes.RawStmt: a.apply(&n, "Stmt", nil, n.Stmt) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ReassignOwnedStmt: a.apply(&n, "Roles", nil, n.Roles) if n.Newrole != nil { a.apply(&n, "Newrole", nil, *n.Newrole) } - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.RefreshMatViewStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) - a.cursor.Replace(n) + a.cursor.set(n, &n) } case nodes.ReindexStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) - a.cursor.Replace(n) + a.cursor.set(n, &n) } case nodes.RelabelType: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.RenameStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) } a.apply(&n, "Object", nil, n.Object) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ReplicaIdentityStmt: // pass @@ -1148,7 +1166,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.ResTarget: a.apply(&n, "Indirection", nil, n.Indirection) a.apply(&n, "Val", nil, n.Val) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.RoleSpec: // pass @@ -1160,13 +1178,13 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Inputcollids", nil, n.Inputcollids) a.apply(&n, "Largs", nil, n.Largs) a.apply(&n, "Rargs", nil, n.Rargs) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.RowExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) a.apply(&n, "Colnames", nil, n.Colnames) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.RowMarkClause: // pass @@ -1177,20 +1195,20 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node } a.apply(&n, "WhereClause", nil, n.WhereClause) a.apply(&n, "Actions", nil, n.Actions) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SQLValueFunction: a.apply(&n, "Xpr", nil, n.Xpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.ScalarArrayOpExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SecLabelStmt: a.apply(&n, "Object", nil, n.Object) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SelectStmt: a.apply(&n, "DistinctClause", nil, n.DistinctClause) @@ -1223,7 +1241,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node if n.Rarg != nil { a.apply(&n, "Rarg", nil, *n.Rarg) } - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SetOperationStmt: a.apply(&n, "Larg", nil, n.Larg) @@ -1232,16 +1250,16 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "ColTypmods", nil, n.ColTypmods) a.apply(&n, "ColCollations", nil, n.ColCollations) a.apply(&n, "GroupClauses", nil, n.GroupClauses) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SetToDefault: a.apply(&n, "Xpr", nil, n.Xpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SortBy: a.apply(&n, "Node", nil, n.Node) a.apply(&n, "UseOp", nil, n.UseOp) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SortGroupClause: // pass @@ -1254,7 +1272,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Testexpr", nil, n.Testexpr) a.apply(&n, "Opername", nil, n.OperName) a.apply(&n, "Subselect", nil, n.Subselect) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.SubPlan: a.apply(&n, "Xpr", nil, n.Xpr) @@ -1263,7 +1281,7 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "SetParam", nil, n.SetParam) a.apply(&n, "ParParam", nil, n.ParParam) a.apply(&n, "Args", nil, n.Args) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.TableFunc: a.apply(&n, "NsUris", nil, n.NsUris) @@ -1276,51 +1294,47 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Colcollations", nil, n.Colcollations) a.apply(&n, "Colexprs", nil, n.Colexprs) a.apply(&n, "Coldefexprs", nil, n.Coldefexprs) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.TableLikeClause: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) - a.cursor.Replace(n) + a.cursor.set(n, &n) } case nodes.TableSampleClause: a.apply(&n, "Args", nil, n.Args) a.apply(&n, "Repeatable", nil, n.Repeatable) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.TargetEntry: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Expr", nil, n.Expr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.TransactionStmt: a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.TriggerTransition: // pass case nodes.TruncateStmt: a.apply(&n, "Relations", nil, n.Relations) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.TypeCast: a.apply(&n, "Arg", nil, n.Arg) - a.apply(&n, "TypeName", nil, n.TypeName) - a.cursor.Replace(n) + if n.TypeName != nil { + a.apply(&n, "TypeName", nil, *n.TypeName) + } + a.cursor.set(n, &n) case nodes.TypeName: a.apply(&n, "Names", nil, n.Names) a.apply(&n, "Typmods", nil, n.Typmods) a.apply(&n, "ArrayBounds", nil, n.ArrayBounds) - a.cursor.Replace(n) - - case *nodes.TypeName: - a.apply(n, "Names", nil, n.Names) - a.apply(n, "Typmods", nil, n.Typmods) - a.apply(n, "ArrayBounds", nil, n.ArrayBounds) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.UnlistenStmt: // pass @@ -1336,22 +1350,22 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node if n.WithClause != nil { a.apply(&n, "WithClause", nil, *n.WithClause) } - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.VacuumStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) } a.apply(&n, "VaCols", nil, n.VaCols) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.Var: a.apply(&n, "Xpr", nil, n.Xpr) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.VariableSetStmt: a.apply(&n, "Args", nil, n.Args) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.VariableShowStmt: // pass @@ -1363,49 +1377,49 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Aliases", nil, n.Aliases) a.apply(&n, "Query", nil, n.Query) a.apply(&n, "Options", nil, n.Options) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.WindowClause: a.apply(&n, "PartitionClause", nil, n.PartitionClause) a.apply(&n, "OrderClause", nil, n.OrderClause) a.apply(&n, "StartOffset", nil, n.StartOffset) a.apply(&n, "EndOffset", nil, n.EndOffset) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.WindowDef: a.apply(&n, "PartitionClause", nil, n.PartitionClause) a.apply(&n, "OrderClause", nil, n.OrderClause) a.apply(&n, "StartOffset", nil, n.StartOffset) a.apply(&n, "EndOffset", nil, n.EndOffset) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.WindowFunc: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) a.apply(&n, "Aggfilter", nil, n.Aggfilter) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.WithCheckOption: a.apply(&n, "Qual", nil, n.Qual) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.WithClause: a.apply(&n, "Ctes", nil, n.Ctes) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.XmlExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "NamedArgs", nil, n.NamedArgs) a.apply(&n, "ArgNames", nil, n.ArgNames) a.apply(&n, "Args", nil, n.Args) - a.cursor.Replace(n) + a.cursor.set(n, &n) case nodes.XmlSerialize: a.apply(&n, "Expr", nil, n.Expr) if n.TypeName != nil { a.apply(&n, "TypeName", nil, *n.TypeName) } - a.cursor.Replace(n) + a.cursor.set(n, &n) default: panic(fmt.Sprintf("Apply: unexpected node type %T", n)) From 3fbcc63627e4e77b9710ca7df5165bda27605a8a Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 17 Jan 2020 13:23:31 -0800 Subject: [PATCH 09/11] More names --- internal/postgresql/ast/astutil.go | 280 +++++++++++++++++------------ 1 file changed, 162 insertions(+), 118 deletions(-) diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go index 645673353e..aed2bf54c1 100644 --- a/internal/postgresql/ast/astutil.go +++ b/internal/postgresql/ast/astutil.go @@ -212,8 +212,8 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.AlterDatabaseSetStmt: if n.Setstmt != nil { a.apply(&n, "Setstmt", nil, *n.Setstmt) - a.cursor.set(n, &n) } + a.cursor.set(n, &n) case nodes.AlterDatabaseStmt: a.apply(&n, "Options", nil, n.Options) @@ -224,7 +224,6 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.apply(&n, "Action", nil, *n.Action) } a.apply(&n, "Options", nil, n.Options) - // TODOO: Take a pointer or not: a.cursor.set(n, &n, &n) a.cursor.set(n, &n) case nodes.AlterDomainStmt: @@ -449,8 +448,8 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.ClusterStmt: if n.Relation != nil { a.apply(&n, "Relation", nil, *n.Relation) - a.cursor.set(n, &n) } + a.cursor.set(n, &n) case nodes.CoalesceExpr: a.apply(&n, "Xpr", nil, n.Xpr) @@ -584,200 +583,234 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.cursor.set(n, &n) case nodes.CreateEventTrigStmt: - a.apply(&n, "", nil, n.Whenclause) - a.apply(&n, "", nil, n.Funcname) + a.apply(&n, "Whenclause", nil, n.Whenclause) + a.apply(&n, "Funcname", nil, n.Funcname) + a.cursor.set(n, &n) case nodes.CreateExtensionStmt: - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateFdwStmt: - a.apply(&n, "", nil, n.FuncOptions) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "FuncOptions", nil, n.FuncOptions) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateForeignServerStmt: - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateForeignTableStmt: - a.apply(&n, "", nil, n.Base) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Base", nil, n.Base) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateFunctionStmt: - a.apply(&n, "", nil, n.Funcname) - a.apply(&n, "", nil, n.Parameters) + a.apply(&n, "Funcname", nil, n.Funcname) + a.apply(&n, "Parameters", nil, n.Parameters) if n.ReturnType != nil { - a.apply(&n, "", nil, *n.ReturnType) + a.apply(&n, "ReturnType", nil, *n.ReturnType) } - a.apply(&n, "", nil, n.Options) - a.apply(&n, "", nil, n.WithClause) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "WithClause", nil, n.WithClause) + a.cursor.set(n, &n) case nodes.CreateOpClassItem: - a.apply(&n, "", nil, n.Name) - a.apply(&n, "", nil, n.OrderFamily) - a.apply(&n, "", nil, n.ClassArgs) + a.apply(&n, "Name", nil, n.Name) + a.apply(&n, "OrderFamily", nil, n.OrderFamily) + a.apply(&n, "ClassArgs", nil, n.ClassArgs) if n.Storedtype != nil { - a.apply(&n, "", nil, *n.Storedtype) + a.apply(&n, "Storedtype", nil, *n.Storedtype) } + a.cursor.set(n, &n) case nodes.CreateOpClassStmt: - a.apply(&n, "", nil, n.Opclassname) - a.apply(&n, "", nil, n.Opfamilyname) + a.apply(&n, "Opclassname", nil, n.Opclassname) + a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) if n.Datatype != nil { - a.apply(&n, "", nil, *n.Datatype) + a.apply(&n, "Datatype", nil, *n.Datatype) } - a.apply(&n, "", nil, n.Items) + a.apply(&n, "Items", nil, n.Items) + a.cursor.set(n, &n) case nodes.CreateOpFamilyStmt: - a.apply(&n, "", nil, n.Opfamilyname) + a.apply(&n, "Opfamilyname", nil, n.Opfamilyname) + a.cursor.set(n, &n) case nodes.CreatePLangStmt: - a.apply(&n, "", nil, n.Plhandler) - a.apply(&n, "", nil, n.Plinline) - a.apply(&n, "", nil, n.Plvalidator) + a.apply(&n, "Plhandler", nil, n.Plhandler) + a.apply(&n, "Plinline", nil, n.Plinline) + a.apply(&n, "Plvalidator", nil, n.Plvalidator) + a.cursor.set(n, &n) case nodes.CreatePolicyStmt: if n.Table != nil { - a.apply(&n, "", nil, *n.Table) + a.apply(&n, "Table", nil, *n.Table) } - a.apply(&n, "", nil, n.Roles) - a.apply(&n, "", nil, n.Qual) - a.apply(&n, "", nil, n.WithCheck) + a.apply(&n, "Roles", nil, n.Roles) + a.apply(&n, "Qual", nil, n.Qual) + a.apply(&n, "WithCheck", nil, n.WithCheck) + a.cursor.set(n, &n) case nodes.CreatePublicationStmt: - a.apply(&n, "", nil, n.Options) - a.apply(&n, "", nil, n.Tables) + a.apply(&n, "Options", nil, n.Options) + a.apply(&n, "Tables", nil, n.Tables) + a.cursor.set(n, &n) case nodes.CreateRangeStmt: - a.apply(&n, "", nil, n.TypeName) - a.apply(&n, "", nil, n.Params) + a.apply(&n, "TypeName", nil, n.TypeName) + a.apply(&n, "Params", nil, n.Params) + a.cursor.set(n, &n) case nodes.CreateRoleStmt: - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateSchemaStmt: if n.Authrole != nil { - a.apply(&n, "", nil, *n.Authrole) + a.apply(&n, "Authrole", nil, *n.Authrole) } - a.apply(&n, "", nil, n.SchemaElts) + a.apply(&n, "SchemaElts", nil, n.SchemaElts) + a.cursor.set(n, &n) case nodes.CreateSeqStmt: if n.Sequence != nil { - a.apply(&n, "", nil, *n.Sequence) + a.apply(&n, "Sequence", nil, *n.Sequence) } - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateStatsStmt: - a.apply(&n, "", nil, n.Defnames) - a.apply(&n, "", nil, n.StatTypes) - a.apply(&n, "", nil, n.Exprs) - a.apply(&n, "", nil, n.Relations) + a.apply(&n, "Defnames", nil, n.Defnames) + a.apply(&n, "StatTypes", nil, n.StatTypes) + a.apply(&n, "Exprs", nil, n.Exprs) + a.apply(&n, "Relations", nil, n.Relations) + a.cursor.set(n, &n) case nodes.CreateStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.TableElts) - a.apply(&n, "", nil, n.InhRelations) + a.apply(&n, "TableElts", nil, n.TableElts) + a.apply(&n, "InhRelations", nil, n.InhRelations) if n.Partbound != nil { - a.apply(&n, "", nil, *n.Partbound) + a.apply(&n, "Partbound", nil, *n.Partbound) } if n.Partspec != nil { - a.apply(&n, "", nil, *n.Partspec) + a.apply(&n, "Partspec", nil, *n.Partspec) } - a.apply(&n, "", nil, n.Constraints) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Constraints", nil, n.Constraints) + a.apply(&n, "Options", nil, n.Options) if n.OfTypename != nil { - a.apply(&n, "", nil, *n.OfTypename) + a.apply(&n, "OfTypename", nil, *n.OfTypename) } + a.cursor.set(n, &n) case nodes.CreateSubscriptionStmt: - a.apply(&n, "", nil, n.Publication) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Publication", nil, n.Publication) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateTableAsStmt: - a.apply(&n, "", nil, n.Query) - a.apply(&n, "", nil, n.Into) + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Into", nil, n.Into) + a.cursor.set(n, &n) case nodes.CreateTableSpaceStmt: if n.Owner != nil { - a.apply(&n, "", nil, *n.Owner) + a.apply(&n, "Owner", nil, *n.Owner) } - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreateTransformStmt: if n.TypeName != nil { - a.apply(&n, "", nil, *n.TypeName) + a.apply(&n, "TypeName", nil, *n.TypeName) } if n.Fromsql != nil { - a.apply(&n, "", nil, *n.Fromsql) + a.apply(&n, "Fromsql", nil, *n.Fromsql) } if n.Tosql != nil { - a.apply(&n, "", nil, *n.Tosql) + a.apply(&n, "Tosql", nil, *n.Tosql) } + a.cursor.set(n, &n) case nodes.CreateTrigStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.Funcname) - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.Columns) - a.apply(&n, "", nil, n.WhenClause) - a.apply(&n, "", nil, n.TransitionRels) + a.apply(&n, "Funcname", nil, n.Funcname) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Columns", nil, n.Columns) + a.apply(&n, "WhenClause", nil, n.WhenClause) + a.apply(&n, "TransitionRels", nil, n.TransitionRels) if n.Constrrel != nil { - a.apply(&n, "", nil, *n.Constrrel) + a.apply(&n, "Constrrel", nil, *n.Constrrel) } + a.cursor.set(n, &n) case nodes.CreateUserMappingStmt: if n.User != nil { - a.apply(&n, "", nil, *n.User) + a.apply(&n, "User", nil, *n.User) } - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CreatedbStmt: - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.CurrentOfExpr: - a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) case nodes.DeallocateStmt: // pass case nodes.DeclareCursorStmt: - a.apply(&n, "", nil, n.Query) + a.apply(&n, "Query", nil, n.Query) + a.cursor.set(n, &n) case nodes.DefElem: - a.apply(&n, "", nil, n.Arg) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) case nodes.DefineStmt: - a.apply(&n, "", nil, n.Defnames) - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.Definition) + a.apply(&n, "Defnames", nil, n.Defnames) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Definition", nil, n.Definition) + a.cursor.set(n, &n) case nodes.DeleteStmt: if n.Relation != nil { - a.apply(&n, "", nil, *n.Relation) + a.apply(&n, "Relation", nil, *n.Relation) } - a.apply(&n, "", nil, n.UsingClause) - a.apply(&n, "", nil, n.WhereClause) - a.apply(&n, "", nil, n.ReturningList) + a.apply(&n, "UsingClause", nil, n.UsingClause) + a.apply(&n, "WhereClause", nil, n.WhereClause) + a.apply(&n, "ReturningList", nil, n.ReturningList) if n.WithClause != nil { - a.apply(&n, "", nil, *n.WithClause) + a.apply(&n, "WithClause", nil, *n.WithClause) } + a.cursor.set(n, &n) case nodes.DiscardStmt: // pass case nodes.DoStmt: - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) case nodes.DropOwnedStmt: - a.apply(&n, "", nil, n.Roles) + a.apply(&n, "Roles", nil, n.Roles) + a.cursor.set(n, &n) case nodes.DropRoleStmt: - a.apply(&n, "", nil, n.Roles) + a.apply(&n, "Roles", nil, n.Roles) + a.cursor.set(n, &n) case nodes.DropStmt: - a.apply(&n, "", nil, n.Objects) + a.apply(&n, "Objects", nil, n.Objects) + a.cursor.set(n, &n) case nodes.DropSubscriptionStmt: // pass @@ -787,18 +820,21 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.DropUserMappingStmt: if n.User != nil { - a.apply(&n, "", nil, *n.User) + a.apply(&n, "User", nil, *n.User) } + a.cursor.set(n, &n) case nodes.DropdbStmt: // pass case nodes.ExecuteStmt: - a.apply(&n, "", nil, n.Params) + a.apply(&n, "Params", nil, n.Params) + a.cursor.set(n, &n) case nodes.ExplainStmt: - a.apply(&n, "", nil, n.Query) - a.apply(&n, "", nil, n.Options) + a.apply(&n, "Query", nil, n.Query) + a.apply(&n, "Options", nil, n.Options) + a.cursor.set(n, &n) case nodes.Expr: // pass @@ -807,58 +843,66 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node // pass case nodes.FieldSelect: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Arg) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) case nodes.FieldStore: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Arg) - a.apply(&n, "", nil, n.Newvals) - a.apply(&n, "", nil, n.Fieldnums) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Arg", nil, n.Arg) + a.apply(&n, "Newvals", nil, n.Newvals) + a.apply(&n, "Fieldnums", nil, n.Fieldnums) + a.cursor.set(n, &n) case nodes.Float: // pass case nodes.FromExpr: - a.apply(&n, "", nil, n.Fromlist) - a.apply(&n, "", nil, n.Quals) + a.apply(&n, "Fromlist", nil, n.Fromlist) + a.apply(&n, "Quals", nil, n.Quals) + a.cursor.set(n, &n) case nodes.FuncCall: - a.apply(&n, "", nil, n.Funcname) - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.AggOrder) - a.apply(&n, "", nil, n.AggFilter) + a.apply(&n, "Funcname", nil, n.Funcname) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "AggOrder", nil, n.AggOrder) + a.apply(&n, "AggFilter", nil, n.AggFilter) if n.Over != nil { - a.apply(&n, "", nil, *n.Over) + a.apply(&n, "Over", nil, *n.Over) } + a.cursor.set(n, &n) case nodes.FuncExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) case nodes.FunctionParameter: if n.ArgType != nil { - a.apply(&n, "", nil, *n.ArgType) + a.apply(&n, "ArgType", nil, *n.ArgType) } - a.apply(&n, "", nil, n.Defexpr) + a.apply(&n, "Defexpr", nil, n.Defexpr) + a.cursor.set(n, &n) case nodes.GrantRoleStmt: - a.apply(&n, "", nil, n.GrantedRoles) - a.apply(&n, "", nil, n.GranteeRoles) + a.apply(&n, "GrantedRoles", nil, n.GrantedRoles) + a.apply(&n, "GranteeRoles", nil, n.GranteeRoles) if n.Grantor != nil { - a.apply(&n, "", nil, *n.Grantor) + a.apply(&n, "Grantor", nil, *n.Grantor) } + a.cursor.set(n, &n) case nodes.GrantStmt: - a.apply(&n, "", nil, n.Objects) - a.apply(&n, "", nil, n.Privileges) - a.apply(&n, "", nil, n.Grantees) + a.apply(&n, "Objects", nil, n.Objects) + a.apply(&n, "Privileges", nil, n.Privileges) + a.apply(&n, "Grantees", nil, n.Grantees) + a.cursor.set(n, &n) case nodes.GroupingFunc: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) - a.apply(&n, "", nil, n.Refs) - a.apply(&n, "", nil, n.Cols) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.apply(&n, "Refs", nil, n.Refs) + a.apply(&n, "Cols", nil, n.Cols) a.cursor.set(n, &n) case nodes.GroupingSet: From 12728a6fe2e4f7065b76165e02474fd119706a19 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 17 Jan 2020 17:43:31 -0800 Subject: [PATCH 10/11] Finish astutil --- internal/postgresql/ast/astutil.go | 168 +++++++++++++++++------------ 1 file changed, 97 insertions(+), 71 deletions(-) diff --git a/internal/postgresql/ast/astutil.go b/internal/postgresql/ast/astutil.go index aed2bf54c1..97eade91e8 100644 --- a/internal/postgresql/ast/astutil.go +++ b/internal/postgresql/ast/astutil.go @@ -981,8 +981,8 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node a.cursor.set(n, &n) case nodes.List: + // Since item is a slice a.applyList(&n, "Items") - a.cursor.set(n, &n) case nodes.ListenStmt: // pass @@ -992,23 +992,29 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.LockStmt: a.apply(&n, "Relations", nil, n.Relations) + a.cursor.set(n, &n) case nodes.LockingClause: a.apply(&n, "LockedRels", nil, n.LockedRels) + a.cursor.set(n, &n) case nodes.MinMaxExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) case nodes.MultiAssignRef: a.apply(&n, "Source", nil, n.Source) + a.cursor.set(n, &n) case nodes.NamedArgExpr: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Args", nil, n.Arg) + a.cursor.set(n, &n) case nodes.NextValueExpr: a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) case nodes.NotifyStmt: // pass @@ -1019,10 +1025,12 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.NullTest: a.apply(&n, "Xpr", nil, n.Xpr) a.apply(&n, "Arg", nil, n.Arg) + a.cursor.set(n, &n) case nodes.ObjectWithArgs: a.apply(&n, "Objname", nil, n.Objname) a.apply(&n, "Objargs", nil, n.Objargs) + a.cursor.set(n, &n) case nodes.OnConflictClause: if n.Infer != nil { @@ -1030,20 +1038,24 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node } a.apply(&n, "TargetList", nil, n.TargetList) a.apply(&n, "WhereClause", nil, n.WhereClause) + a.cursor.set(n, &n) case nodes.OnConflictExpr: - a.apply(&n, "", nil, n.ArbiterElems) - a.apply(&n, "", nil, n.ArbiterWhere) - a.apply(&n, "", nil, n.OnConflictSet) - a.apply(&n, "", nil, n.OnConflictWhere) - a.apply(&n, "", nil, n.ExclRelTlist) + a.apply(&n, "ArbiterElems", nil, n.ArbiterElems) + a.apply(&n, "ArbiterWhere", nil, n.ArbiterWhere) + a.apply(&n, "OnConflictSet", nil, n.OnConflictSet) + a.apply(&n, "OnConflictWhere", nil, n.OnConflictWhere) + a.apply(&n, "ExclRelTlist", nil, n.ExclRelTlist) + a.cursor.set(n, &n) case nodes.OpExpr: - a.apply(&n, "", nil, n.Xpr) - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Xpr", nil, n.Xpr) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) case nodes.Param: - a.apply(&n, "", nil, n.Xpr) + a.apply(&n, "Xpr", nil, n.Xpr) + a.cursor.set(n, &n) case nodes.ParamExecData: // pass @@ -1058,107 +1070,121 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node // pass case nodes.PartitionBoundSpec: - a.apply(&n, "", nil, n.Listdatums) - a.apply(&n, "", nil, n.Lowerdatums) - a.apply(&n, "", nil, n.Upperdatums) + a.apply(&n, "Listdatums", nil, n.Listdatums) + a.apply(&n, "Lowerdatums", nil, n.Lowerdatums) + a.apply(&n, "Upperdatums", nil, n.Upperdatums) + a.cursor.set(n, &n) case nodes.PartitionCmd: if n.Name != nil { - a.apply(&n, "", nil, *n.Name) + a.apply(&n, "Name", nil, *n.Name) } if n.Bound != nil { - a.apply(&n, "", nil, *n.Bound) + a.apply(&n, "Bound", nil, *n.Bound) } + a.cursor.set(n, &n) case nodes.PartitionElem: - a.apply(&n, "", nil, n.Expr) - a.apply(&n, "", nil, n.Collation) - a.apply(&n, "", nil, n.Opclass) + a.apply(&n, "Expr", nil, n.Expr) + a.apply(&n, "Collation", nil, n.Collation) + a.apply(&n, "Opclass", nil, n.Opclass) + a.cursor.set(n, &n) case nodes.PartitionRangeDatum: - a.apply(&n, "", nil, n.Value) + a.apply(&n, "Value", nil, n.Value) + a.cursor.set(n, &n) case nodes.PartitionSpec: - a.apply(&n, "", nil, n.PartParams) + a.apply(&n, "PartParams", nil, n.PartParams) + a.cursor.set(n, &n) case nodes.PrepareStmt: - a.apply(&n, "", nil, n.Argtypes) - a.apply(&n, "", nil, n.Query) + a.apply(&n, "Argtypes", nil, n.Argtypes) + a.apply(&n, "Query", nil, n.Query) + a.cursor.set(n, &n) case nodes.Query: - a.apply(&n, "", nil, n.UtilityStmt) - a.apply(&n, "", nil, n.CteList) - a.apply(&n, "", nil, n.Jointree) - a.apply(&n, "", nil, n.TargetList) - a.apply(&n, "", nil, n.OnConflict) - a.apply(&n, "", nil, n.ReturningList) - a.apply(&n, "", nil, n.GroupClause) - a.apply(&n, "", nil, n.GroupingSets) - a.apply(&n, "", nil, n.HavingQual) - a.apply(&n, "", nil, n.WindowClause) - a.apply(&n, "", nil, n.DistinctClause) - a.apply(&n, "", nil, n.SortClause) - a.apply(&n, "", nil, n.LimitCount) - a.apply(&n, "", nil, n.RowMarks) - a.apply(&n, "", nil, n.SetOperations) - a.apply(&n, "", nil, n.ConstraintDeps) - a.apply(&n, "", nil, n.WithCheckOptions) + a.apply(&n, "UtilityStmt", nil, n.UtilityStmt) + a.apply(&n, "CteList", nil, n.CteList) + a.apply(&n, "Jointree", nil, n.Jointree) + a.apply(&n, "TargetList", nil, n.TargetList) + a.apply(&n, "OnConflict", nil, n.OnConflict) + a.apply(&n, "ReturningList", nil, n.ReturningList) + a.apply(&n, "GroupClause", nil, n.GroupClause) + a.apply(&n, "GroupingSets", nil, n.GroupingSets) + a.apply(&n, "HavingQual", nil, n.HavingQual) + a.apply(&n, "WindowClause", nil, n.WindowClause) + a.apply(&n, "DistinctClause", nil, n.DistinctClause) + a.apply(&n, "SortClause", nil, n.SortClause) + a.apply(&n, "LimitCount", nil, n.LimitCount) + a.apply(&n, "RowMarks", nil, n.RowMarks) + a.apply(&n, "SetOperations", nil, n.SetOperations) + a.apply(&n, "ConstraintDeps", nil, n.ConstraintDeps) + a.apply(&n, "WithCheckOptions", nil, n.WithCheckOptions) + a.cursor.set(n, &n) case nodes.RangeFunction: - a.apply(&n, "", nil, n.Functions) + a.apply(&n, "Functions", nil, n.Functions) if n.Alias != nil { - a.apply(&n, "", nil, *n.Alias) + a.apply(&n, "Alias", nil, *n.Alias) } - a.apply(&n, "", nil, n.Coldeflist) + a.apply(&n, "Coldeflist", nil, n.Coldeflist) + a.cursor.set(n, &n) case nodes.RangeSubselect: - a.apply(&n, "", nil, n.Subquery) + a.apply(&n, "Subquery", nil, n.Subquery) if n.Alias != nil { - a.apply(&n, "", nil, *n.Alias) + a.apply(&n, "Alias", nil, *n.Alias) } + a.cursor.set(n, &n) case nodes.RangeTableFunc: - a.apply(&n, "", nil, n.Docexpr) - a.apply(&n, "", nil, n.Rowexpr) - a.apply(&n, "", nil, n.Namespaces) - a.apply(&n, "", nil, n.Columns) + a.apply(&n, "Docexpr", nil, n.Docexpr) + a.apply(&n, "Rowexpr", nil, n.Rowexpr) + a.apply(&n, "Namespaces", nil, n.Namespaces) + a.apply(&n, "Columns", nil, n.Columns) if n.Alias != nil { - a.apply(&n, "", nil, *n.Alias) + a.apply(&n, "Alias", nil, *n.Alias) } + a.cursor.set(n, &n) case nodes.RangeTableFuncCol: if n.TypeName != nil { - a.apply(&n, "", nil, *n.TypeName) + a.apply(&n, "TypeName", nil, *n.TypeName) } - a.apply(&n, "", nil, n.Colexpr) - a.apply(&n, "", nil, n.Coldefexpr) + a.apply(&n, "Colexpr", nil, n.Colexpr) + a.apply(&n, "Coldefexpr", nil, n.Coldefexpr) + a.cursor.set(n, &n) case nodes.RangeTableSample: - a.apply(&n, "", nil, n.Relation) - a.apply(&n, "", nil, n.Method) - a.apply(&n, "", nil, n.Args) + a.apply(&n, "Relation", nil, n.Relation) + a.apply(&n, "Method", nil, n.Method) + a.apply(&n, "Args", nil, n.Args) + a.cursor.set(n, &n) case nodes.RangeTblEntry: - a.apply(&n, "", nil, n.Tablesample) - a.apply(&n, "", nil, n.Subquery) - a.apply(&n, "", nil, n.Joinaliasvars) - a.apply(&n, "", nil, n.Functions) - a.apply(&n, "", nil, n.Tablefunc) - a.apply(&n, "", nil, n.ValuesLists) - a.apply(&n, "", nil, n.Coltypes) - a.apply(&n, "", nil, n.Colcollations) + a.apply(&n, "Tablesample", nil, n.Tablesample) + a.apply(&n, "Subquery", nil, n.Subquery) + a.apply(&n, "Joinaliasvars", nil, n.Joinaliasvars) + a.apply(&n, "Functions", nil, n.Functions) + a.apply(&n, "Tablefund", nil, n.Tablefunc) + a.apply(&n, "ValuesLists", nil, n.ValuesLists) + a.apply(&n, "Coltypes", nil, n.Coltypes) + a.apply(&n, "Colcollations", nil, n.Colcollations) if n.Alias != nil { - a.apply(&n, "", nil, *n.Alias) + a.apply(&n, "Alias", nil, *n.Alias) } - a.apply(&n, "", nil, n.Eref) - a.apply(&n, "", nil, n.SecurityQuals) + a.apply(&n, "Eref", nil, n.Eref) + a.apply(&n, "SecurityQuals", nil, n.SecurityQuals) + a.cursor.set(n, &n) case nodes.RangeTblFunction: - a.apply(&n, "", nil, n.Funcexpr) - a.apply(&n, "", nil, n.Funccolnames) - a.apply(&n, "", nil, n.Funccoltypes) - a.apply(&n, "", nil, n.Funccoltypmods) - a.apply(&n, "", nil, n.Funccolcollations) + a.apply(&n, "Funcexpr", nil, n.Funcexpr) + a.apply(&n, "Funccolnames", nil, n.Funccolnames) + a.apply(&n, "Funccoltypes", nil, n.Funccoltypes) + a.apply(&n, "Funccoltypmods", nil, n.Funccoltypmods) + a.apply(&n, "Funccolcollations", nil, n.Funccolcollations) + a.cursor.set(n, &n) case nodes.RangeTblRef: // pass @@ -1166,8 +1192,8 @@ func (a *application) apply(parent nodes.Node, name string, iter *iterator, node case nodes.RangeVar: if n.Alias != nil { a.apply(&n, "Alias", nil, *n.Alias) - a.cursor.set(n, &n) } + a.cursor.set(n, &n) case nodes.RawStmt: a.apply(&n, "Stmt", nil, n.Stmt) From d14a3b3d2dad15ae22bd8ba9f9f4a0944221f520 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 17 Jan 2020 18:21:18 -0800 Subject: [PATCH 11/11] Almost there --- internal/dinosql/checks.go | 23 ++++++---- internal/dinosql/query_test.go | 17 ++++++++ internal/dinosql/rewrite.go | 77 +++++++++++++++++++++++++++++++--- 3 files changed, 104 insertions(+), 13 deletions(-) diff --git a/internal/dinosql/checks.go b/internal/dinosql/checks.go index 7d654ecef3..610acfac7f 100644 --- a/internal/dinosql/checks.go +++ b/internal/dinosql/checks.go @@ -123,18 +123,27 @@ func validateInsertStmt(stmt nodes.InsertStmt) error { return nil } -// A query can either use named parameters (sqlc.arg(param)) or positional -// parameters ($1), but not both +// A query can use one (and only one) of the following formats: +// - positional parameters $1 +// - named parameter operator @param +// - named parameter function calls sqlc.arg(param) func validateParamStyle(n nodes.Node) error { positional := search(n, func(node nodes.Node) bool { _, ok := node.(nodes.ParamRef) return ok }) - named := search(n, isNamedParamFunc) - if len(named.Items) > 0 && len(positional.Items) > 0 { - return pg.Error{ - Code: "", // TODO: Pick a new error code - Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg)", + namedFunc := search(n, isNamedParamFunc) + namedSign := search(n, isNamedParamSign) + for _, check := range []bool{ + len(positional.Items) > 0 && len(namedSign.Items)+len(namedFunc.Items) > 0, + len(namedFunc.Items) > 0 && len(namedSign.Items)+len(positional.Items) > 0, + len(namedSign.Items) > 0 && len(positional.Items)+len(namedFunc.Items) > 0, + } { + if check { + return pg.Error{ + Code: "", // TODO: Pick a new error code + Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)", + } } } return nil diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index b763dd4c0c..9c22e3f086 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -901,6 +901,23 @@ func TestQueries(t *testing.T) { }, }, }, + { + "at_parameter", + ` + CREATE TABLE foo (name text not null); + SELECT name FROM foo WHERE name = @slug AND @filter::bool; + `, + Query{ + SQL: "SELECT name FROM foo WHERE name = $1 AND $2::bool", + Columns: []core.Column{ + {Table: public("foo"), Name: "name", DataType: "text", NotNull: true}, + }, + Params: []Parameter{ + {1, core.Column{Table: public("foo"), Name: "slug", DataType: "text", NotNull: true}}, + {2, core.Column{Name: "filter", DataType: "bool", NotNull: true}}, + }, + }, + }, } { test := tc t.Run(test.name, func(t *testing.T) { diff --git a/internal/dinosql/rewrite.go b/internal/dinosql/rewrite.go index fa363ec952..95b5c457a8 100644 --- a/internal/dinosql/rewrite.go +++ b/internal/dinosql/rewrite.go @@ -32,13 +32,23 @@ func isNamedParamFunc(node nodes.Node) bool { } func isNamedParamSign(node nodes.Node) bool { - fun, ok := node.(nodes.FuncCall) - return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg" + expr, ok := node.(nodes.A_Expr) + return ok && ast.Join(expr.Name, ".") == "@" +} + +func isNamedParamSignCast(node nodes.Node) bool { + expr, ok := node.(nodes.A_Expr) + if !ok { + return false + } + _, cast := expr.Rexpr.(nodes.TypeCast) + return ast.Join(expr.Name, ".") == "@" && cast } func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) { - found := search(raw, isNamedParamFunc) - if len(found.Items) == 0 { + foundFunc := search(raw, isNamedParamFunc) + foundSign := search(raw, isNamedParamSign) + if len(foundFunc.Items)+len(foundSign.Items) == 0 { return raw, map[int]string{}, nil } @@ -47,7 +57,9 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ var edits []edit node := ast.Apply(raw, func(cr *ast.Cursor) bool { node := cr.Node() - if isNamedParamFunc(node) { + switch { + + case isNamedParamFunc(node): fun := node.(nodes.FuncCall) param := flatten(fun.Args) if num, ok := args[param]; ok { @@ -70,8 +82,61 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [ New: fmt.Sprintf("$%d", args[param]), }) return false + + case isNamedParamSignCast(node): + expr := node.(nodes.A_Expr) + cast := expr.Rexpr.(nodes.TypeCast) + param := flatten(cast.Arg) + if num, ok := args[param]; ok { + cast.Arg = nodes.ParamRef{ + Number: num, + Location: expr.Location, + } + cr.Replace(cast) + } else { + argn += 1 + args[param] = argn + cast.Arg = nodes.ParamRef{ + Number: argn, + Location: expr.Location, + } + cr.Replace(cast) + } + // TODO: This code assumes that @foo::bool is on a single line + edits = append(edits, edit{ + Location: expr.Location - raw.StmtLocation, + Old: fmt.Sprintf("@%s", param), + New: fmt.Sprintf("$%d", args[param]), + }) + return false + + case isNamedParamSign(node): + expr := node.(nodes.A_Expr) + param := flatten(expr.Rexpr) + if num, ok := args[param]; ok { + cr.Replace(nodes.ParamRef{ + Number: num, + Location: expr.Location, + }) + } else { + argn += 1 + args[param] = argn + cr.Replace(nodes.ParamRef{ + Number: argn, + Location: expr.Location, + }) + } + // TODO: This code assumes that @foo is on a single line + edits = append(edits, edit{ + Location: expr.Location - raw.StmtLocation, + Old: fmt.Sprintf("@%s", param), + New: fmt.Sprintf("$%d", args[param]), + }) + return false + + default: + return true } - return true }, nil) named := map[int]string{}