diff --git a/internal/compiler/expand.go b/internal/compiler/expand.go index 66862e33fb..da2c078a7d 100644 --- a/internal/compiler/expand.go +++ b/internal/compiler/expand.go @@ -55,7 +55,7 @@ func (c *Compiler) quoteIdent(ident string) string { } func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) { - tables, err := sourceTables(qc, node) + tables, err := c.sourceTables(qc, node) if err != nil { return nil, err } diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index f71fc3b461..36cafa3b22 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -14,11 +14,11 @@ import ( // OutputColumns determines which columns a statement will output func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) { - qc, err := buildQueryCatalog(c.catalog, stmt, nil) + qc, err := c.buildQueryCatalog(c.catalog, stmt, nil) if err != nil { return nil, err } - cols, err := outputColumns(qc, stmt) + cols, err := c.outputColumns(qc, stmt) if err != nil { return nil, err } @@ -51,8 +51,8 @@ func hasStarRef(cf *ast.ColumnRef) bool { // // Return an error if column references are ambiguous // Return an error if column references don't exist -func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { - tables, err := sourceTables(qc, node) +func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { + tables, err := c.sourceTables(qc, node) if err != nil { return nil, err } @@ -68,21 +68,50 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { if n.GroupClause != nil { for _, item := range n.GroupClause.Items { - ref, ok := item.(*ast.ColumnRef) - if !ok { - continue - } - - if err := findColumnForRef(ref, tables, n); err != nil { + if err := findColumnForNode(item, tables, n); err != nil { return nil, err } } } + validateOrderBy := true + if c.conf.StrictOrderBy != nil { + validateOrderBy = *c.conf.StrictOrderBy + } + if validateOrderBy { + if n.SortClause != nil { + for _, item := range n.SortClause.Items { + sb, ok := item.(*ast.SortBy) + if !ok { + continue + } + if err := findColumnForNode(sb.Node, tables, n); err != nil { + return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err) + } + } + } + if n.WindowClause != nil { + for _, item := range n.WindowClause.Items { + sb, ok := item.(*ast.List) + if !ok { + continue + } + for _, single := range sb.Items { + caseExpr, ok := single.(*ast.CaseExpr) + if !ok { + continue + } + if err := findColumnForNode(caseExpr.Xpr, tables, n); err != nil { + return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err) + } + } + } + } + } // For UNION queries, targets is empty and we need to look for the // columns in Largs. if len(targets.Items) == 0 && n.Larg != nil { - return outputColumns(qc, n.Larg) + return c.outputColumns(qc, n.Larg) } case *ast.CallStmt: targets = &ast.List{} @@ -303,7 +332,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { case ast.EXISTS_SUBLINK: cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) case ast.EXPR_SUBLINK: - subcols, err := outputColumns(qc, n.Subselect) + subcols, err := c.outputColumns(qc, n.Subselect) if err != nil { return nil, err } @@ -339,7 +368,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) { cols = append(cols, col) case *ast.SelectStmt: - subcols, err := outputColumns(qc, n) + subcols, err := c.outputColumns(qc, n) if err != nil { return nil, err } @@ -428,7 +457,7 @@ func isTableRequired(n ast.Node, col *Column, prior int) int { // Return an error if column references don't exist // Return an error if a table is referenced twice // Return an error if an unknown column is referenced -func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { +func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { var list *ast.List switch n := node.(type) { case *ast.DeleteStmt: @@ -483,7 +512,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { tables = append(tables, table) case *ast.RangeSubselect: - cols, err := outputColumns(qc, n.Subquery) + cols, err := c.outputColumns(qc, n.Subquery) if err != nil { return nil, err } @@ -581,6 +610,14 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef) return cols, nil } +func findColumnForNode(item ast.Node, tables []*Table, n *ast.SelectStmt) error { + ref, ok := item.(*ast.ColumnRef) + if !ok { + return nil + } + return findColumnForRef(ref, tables, n) +} + func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.SelectStmt) error { parts := stringSlice(ref.Fields) var alias, name string diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 2ecff7068a..c7b6f7f3f8 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -86,9 +86,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, } else { sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) } - raw, embeds := rewrite.Embeds(raw) - qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds) + qc, err := c.buildQueryCatalog(c.catalog, raw.Stmt, embeds) if err != nil { return nil, err } @@ -97,7 +96,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, if err != nil { return nil, err } - cols, err := outputColumns(qc, raw.Stmt) + cols, err := c.outputColumns(qc, raw.Stmt) if err != nil { return nil, err } diff --git a/internal/compiler/query_catalog.go b/internal/compiler/query_catalog.go index 96131de729..2b6577c2e9 100644 --- a/internal/compiler/query_catalog.go +++ b/internal/compiler/query_catalog.go @@ -14,7 +14,7 @@ type QueryCatalog struct { embeds rewrite.EmbedSet } -func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) { +func (comp *Compiler) buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) { var with *ast.WithClause switch n := node.(type) { case *ast.DeleteStmt: @@ -32,7 +32,7 @@ func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSe if with != nil { for _, item := range with.Ctes.Items { if cte, ok := item.(*ast.CommonTableExpr); ok { - cols, err := outputColumns(qc, cte.Ctequery) + cols, err := comp.outputColumns(qc, cte.Ctequery) if err != nil { return nil, err } diff --git a/internal/config/config.go b/internal/config/config.go index a0b1a8f382..b0479a6f12 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -99,6 +99,7 @@ type SQL struct { Schema Paths `json:"schema" yaml:"schema"` Queries Paths `json:"queries" yaml:"queries"` StrictFunctionChecks bool `json:"strict_function_checks" yaml:"strict_function_checks"` + StrictOrderBy *bool `json:"strict_order_by" yaml:"strict_order_by"` Gen SQLGen `json:"gen" yaml:"gen"` Codegen []Codegen `json:"codegen" yaml:"codegen"` } diff --git a/internal/config/v_one.go b/internal/config/v_one.go index 369b8ac457..b0f1bfae04 100644 --- a/internal/config/v_one.go +++ b/internal/config/v_one.go @@ -46,6 +46,7 @@ type v1PackageSettings struct { OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"` OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"` StrictFunctionChecks bool `json:"strict_function_checks" yaml:"strict_function_checks"` + StrictOrderBy *bool `json:"strict_order_by" yaml:"strict_order_by"` QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"` } @@ -130,6 +131,10 @@ func (c *V1GenerateSettings) Translate() Config { } for _, pkg := range c.Packages { + if pkg.StrictOrderBy == nil { + defaultValue := true + pkg.StrictOrderBy = &defaultValue + } conf.SQL = append(conf.SQL, SQL{ Engine: pkg.Engine, Schema: pkg.Schema, @@ -164,6 +169,7 @@ func (c *V1GenerateSettings) Translate() Config { }, }, StrictFunctionChecks: pkg.StrictFunctionChecks, + StrictOrderBy: pkg.StrictOrderBy, }) } diff --git a/internal/config/v_two.go b/internal/config/v_two.go index fb86b14446..e2c97a2749 100644 --- a/internal/config/v_two.go +++ b/internal/config/v_two.go @@ -110,6 +110,10 @@ func v2ParseConfig(rd io.Reader) (Config, error) { return conf, ErrPluginNotFound } } + if conf.SQL[j].StrictOrderBy == nil { + defaultValidate := true + conf.SQL[j].StrictOrderBy = &defaultValidate + } } return conf, nil } diff --git a/internal/endtoend/testdata/order_by_non_existing_column/mysql/query.sql b/internal/endtoend/testdata/order_by_non_existing_column/mysql/query.sql new file mode 100644 index 0000000000..b1a4b4f638 --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/mysql/query.sql @@ -0,0 +1,8 @@ +-- Example queries for sqlc +CREATE TABLE authors ( + id INT +); + +-- name: ListAuthors :many +SELECT id FROM authors +ORDER BY adfadsf; \ No newline at end of file diff --git a/internal/endtoend/testdata/order_by_non_existing_column/mysql/sqlc.yaml b/internal/endtoend/testdata/order_by_non_existing_column/mysql/sqlc.yaml new file mode 100644 index 0000000000..c4b3831631 --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/mysql/sqlc.yaml @@ -0,0 +1,7 @@ +version: 1 +packages: + - path: "go" + name: "querytest" + engine: "postgresql" + schema: "query.sql" + queries: "query.sql" \ No newline at end of file diff --git a/internal/endtoend/testdata/order_by_non_existing_column/mysql/stderr.txt b/internal/endtoend/testdata/order_by_non_existing_column/mysql/stderr.txt new file mode 100644 index 0000000000..166178156e --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/mysql/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false diff --git a/internal/endtoend/testdata/order_by_non_existing_column/postgresql/query.sql b/internal/endtoend/testdata/order_by_non_existing_column/postgresql/query.sql new file mode 100644 index 0000000000..b1a4b4f638 --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/postgresql/query.sql @@ -0,0 +1,8 @@ +-- Example queries for sqlc +CREATE TABLE authors ( + id INT +); + +-- name: ListAuthors :many +SELECT id FROM authors +ORDER BY adfadsf; \ No newline at end of file diff --git a/internal/endtoend/testdata/order_by_non_existing_column/postgresql/sqlc.yaml b/internal/endtoend/testdata/order_by_non_existing_column/postgresql/sqlc.yaml new file mode 100644 index 0000000000..c4b3831631 --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/postgresql/sqlc.yaml @@ -0,0 +1,7 @@ +version: 1 +packages: + - path: "go" + name: "querytest" + engine: "postgresql" + schema: "query.sql" + queries: "query.sql" \ No newline at end of file diff --git a/internal/endtoend/testdata/order_by_non_existing_column/postgresql/stderr.txt b/internal/endtoend/testdata/order_by_non_existing_column/postgresql/stderr.txt new file mode 100644 index 0000000000..166178156e --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/postgresql/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false diff --git a/internal/endtoend/testdata/order_by_non_existing_column/sqlite/query.sql b/internal/endtoend/testdata/order_by_non_existing_column/sqlite/query.sql new file mode 100644 index 0000000000..b1a4b4f638 --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/sqlite/query.sql @@ -0,0 +1,8 @@ +-- Example queries for sqlc +CREATE TABLE authors ( + id INT +); + +-- name: ListAuthors :many +SELECT id FROM authors +ORDER BY adfadsf; \ No newline at end of file diff --git a/internal/endtoend/testdata/order_by_non_existing_column/sqlite/sqlc.yaml b/internal/endtoend/testdata/order_by_non_existing_column/sqlite/sqlc.yaml new file mode 100644 index 0000000000..c4b3831631 --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/sqlite/sqlc.yaml @@ -0,0 +1,7 @@ +version: 1 +packages: + - path: "go" + name: "querytest" + engine: "postgresql" + schema: "query.sql" + queries: "query.sql" \ No newline at end of file diff --git a/internal/endtoend/testdata/order_by_non_existing_column/sqlite/stderr.txt b/internal/endtoend/testdata/order_by_non_existing_column/sqlite/stderr.txt new file mode 100644 index 0000000000..166178156e --- /dev/null +++ b/internal/endtoend/testdata/order_by_non_existing_column/sqlite/stderr.txt @@ -0,0 +1,2 @@ +# package querytest +query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false