diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 65a2d5853c..f8b7150435 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -359,6 +359,7 @@ func isTableRequired(n ast.Node, col *Column, prior int) int { // Return an error if an unknown column is referenced func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { var list *ast.List + var nullableNodes []*ast.Node switch n := node.(type) { case *ast.DeleteStmt: list = &ast.List{ @@ -369,14 +370,29 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { Items: []ast.Node{n.Relation}, } case *ast.SelectStmt: - list = astutils.Search(n.FromClause, func(node ast.Node) bool { - switch node.(type) { + dirtyList := astutils.Search(n.FromClause, func(node ast.Node) bool { + switch t := node.(type) { case *ast.RangeVar, *ast.RangeSubselect, *ast.FuncName: return true + case *ast.JoinExpr: + if t.Jointype == ast.JoinTypeLeft || t.Jointype == ast.JoinTypeFull { + return true + } + return false default: return false } }) + // split result into nullableJoins and list of nodes to handle + list = &ast.List{} + for i, v := range dirtyList.Items { + switch t := dirtyList.Items[i].(type) { + case *ast.JoinExpr: + nullableNodes = append(nullableNodes, &t.Rarg) + default: + list.Items = append(list.Items, v) + } + } case *ast.TruncateStmt: list = astutils.Search(n.Relations, func(node ast.Node) bool { _, ok := node.(*ast.RangeVar) @@ -413,9 +429,15 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) { case *ast.RangeSubselect: cols, err := outputColumns(qc, n.Subquery) + if err != nil { return nil, err } + if astutils.IsChildOfNodes(nullableNodes, &item) { + for _, c := range cols { + c.NotNull = false + } + } tables = append(tables, &Table{ Rel: &ast.TableName{ Name: *n.Alias.Aliasname, diff --git a/internal/endtoend/testdata/join_left/postgresql/go/query.sql.go b/internal/endtoend/testdata/join_left/postgresql/go/query.sql.go index f59a55bf35..a78277b3f2 100644 --- a/internal/endtoend/testdata/join_left/postgresql/go/query.sql.go +++ b/internal/endtoend/testdata/join_left/postgresql/go/query.sql.go @@ -361,6 +361,51 @@ func (q *Queries) GetMayorsOptional(ctx context.Context) ([]GetMayorsOptionalRow return items, nil } +const getMayorsOptionalInnerSelect = `-- name: GetMayorsOptionalInnerSelect :many +SELECT t1.user_id, t2.full_name +FROM ( + SELECT user_id, city_id FROM users WHERE users.city_id = $1 LIMIT 1 OFFSET 0 +) AS t1 +LEFT JOIN cities on t1.city_id = cities.city_id +LEFT JOIN ( + SELECT mayors.mayor_id, mayors.full_name + FROM mayors where mayors.mayor_id = $2 +) AS t2 on t2.mayor_id = cities.mayor_id +` + +type GetMayorsOptionalInnerSelectParams struct { + CityID sql.NullInt32 + MayorID int32 +} + +type GetMayorsOptionalInnerSelectRow struct { + UserID int32 + FullName sql.NullString +} + +func (q *Queries) GetMayorsOptionalInnerSelect(ctx context.Context, arg GetMayorsOptionalInnerSelectParams) ([]GetMayorsOptionalInnerSelectRow, error) { + rows, err := q.db.QueryContext(ctx, getMayorsOptionalInnerSelect, arg.CityID, arg.MayorID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetMayorsOptionalInnerSelectRow + for rows.Next() { + var i GetMayorsOptionalInnerSelectRow + if err := rows.Scan(&i.UserID, &i.FullName); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getSuggestedUsersByID = `-- name: GetSuggestedUsersByID :many SELECT DISTINCT u.user_id, u.user_nickname, u.user_email, u.user_display_name, u.user_password, u.user_google_id, u.user_apple_id, u.user_bio, u.user_created_at, u.user_avatar_id, m.media_id, m.media_created_at, m.media_hash, m.media_directory, m.media_author_id, m.media_width, m.media_height FROM users_2 u diff --git a/internal/endtoend/testdata/join_left/postgresql/query.sql b/internal/endtoend/testdata/join_left/postgresql/query.sql index 816f154aff..f099d988ca 100644 --- a/internal/endtoend/testdata/join_left/postgresql/query.sql +++ b/internal/endtoend/testdata/join_left/postgresql/query.sql @@ -111,3 +111,14 @@ FROM users_2 u ON u.user_avatar_id = m.media_id WHERE u.user_id != @user_id LIMIT @user_imit; + +-- name: GetMayorsOptionalInnerSelect :many +SELECT t1.user_id, t2.full_name +FROM ( + SELECT user_id, city_id FROM users WHERE users.city_id = $1 LIMIT 1 OFFSET 0 +) AS t1 +LEFT JOIN cities on t1.city_id = cities.city_id +LEFT JOIN ( + SELECT mayors.mayor_id, mayors.full_name + FROM mayors where mayors.mayor_id = $2 +) AS t2 on t2.mayor_id = cities.mayor_id; diff --git a/internal/sql/astutils/search.go b/internal/sql/astutils/search.go index 5aeacfb9d9..7d1995ce4c 100644 --- a/internal/sql/astutils/search.go +++ b/internal/sql/astutils/search.go @@ -19,3 +19,25 @@ func Search(root ast.Node, f func(ast.Node) bool) *ast.List { Walk(ns, root) return ns.list } + +func IsChildOfNodes(parents []*ast.Node, node *ast.Node) bool { + for _, v := range parents { + if IsChildOfNode(v, node) { + return true + } + } + return false +} + +func IsChildOfNode(parent *ast.Node, node *ast.Node) bool { + res := Search(*parent, func(n ast.Node) bool { + if n == *node { + return true + } + return false + }) + if len(res.Items) > 0 { + return true + } + return false +}