diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 19128f71ae..e31a4eea9b 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -173,12 +173,17 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er name = *res.Name } notNull := false - if n.Boolop == ast.BoolExprTypeNot && len(n.Args.Items) == 1 { - sublink, ok := n.Args.Items[0].(*ast.SubLink) - if ok && sublink.SubLinkType == ast.EXISTS_SUBLINK { + if len(n.Args.Items) == 1 { + switch n.Boolop { + case ast.BoolExprTypeIsNull, ast.BoolExprTypeIsNotNull: notNull = true - if name == "" { - name = "not_exists" + case ast.BoolExprTypeNot: + sublink, ok := n.Args.Items[0].(*ast.SubLink) + if ok && sublink.SubLinkType == ast.EXISTS_SUBLINK { + notNull = true + if name == "" { + name = "not_exists" + } } } } diff --git a/internal/endtoend/testdata/comparisons/mysql/go/query.sql.go b/internal/endtoend/testdata/comparisons/mysql/go/query.sql.go index cd77cc1b42..b823f5f383 100644 --- a/internal/endtoend/testdata/comparisons/mysql/go/query.sql.go +++ b/internal/endtoend/testdata/comparisons/mysql/go/query.sql.go @@ -117,6 +117,60 @@ func (q *Queries) GreaterThanOrEqual(ctx context.Context) ([]bool, error) { return items, nil } +const isNotNull = `-- name: IsNotNull :many +SELECT id IS NOT NULL FROM bar +` + +func (q *Queries) IsNotNull(ctx context.Context) ([]bool, error) { + rows, err := q.db.QueryContext(ctx, isNotNull) + if err != nil { + return nil, err + } + defer rows.Close() + var items []bool + for rows.Next() { + var column_1 bool + if err := rows.Scan(&column_1); err != nil { + return nil, err + } + items = append(items, column_1) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const isNull = `-- name: IsNull :many +SELECT id IS NULL FROM bar +` + +func (q *Queries) IsNull(ctx context.Context) ([]bool, error) { + rows, err := q.db.QueryContext(ctx, isNull) + if err != nil { + return nil, err + } + defer rows.Close() + var items []bool + for rows.Next() { + var column_1 bool + if err := rows.Scan(&column_1); err != nil { + return nil, err + } + items = append(items, column_1) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const lessThan = `-- name: LessThan :many SELECT count(*) < 0 FROM bar ` diff --git a/internal/endtoend/testdata/comparisons/mysql/query.sql b/internal/endtoend/testdata/comparisons/mysql/query.sql index 3c609f4660..73714ec749 100644 --- a/internal/endtoend/testdata/comparisons/mysql/query.sql +++ b/internal/endtoend/testdata/comparisons/mysql/query.sql @@ -24,8 +24,8 @@ SELECT count(*) <> 0 FROM bar; -- name: Equal :many SELECT count(*) = 0 FROM bar; +-- name: IsNull :many +SELECT id IS NULL FROM bar; - - - - +-- name: IsNotNull :many +SELECT id IS NOT NULL FROM bar; diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index 6b03774ebc..4a4478c65a 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -953,7 +953,18 @@ func (c *cc) convertIndexPartSpecification(n *pcast.IndexPartSpecification) ast. } func (c *cc) convertIsNullExpr(n *pcast.IsNullExpr) ast.Node { - return todo(n) + op := ast.BoolExprTypeIsNull + if n.Not { + op = ast.BoolExprTypeIsNotNull + } + return &ast.BoolExpr{ + Boolop: op, + Args: &ast.List{ + Items: []ast.Node{ + c.convert(n.Expr), + }, + }, + } } func (c *cc) convertIsTruthExpr(n *pcast.IsTruthExpr) ast.Node { diff --git a/internal/sql/ast/bool_expr_type.go b/internal/sql/ast/bool_expr_type.go index 0296a9691d..7a4068d102 100644 --- a/internal/sql/ast/bool_expr_type.go +++ b/internal/sql/ast/bool_expr_type.go @@ -6,6 +6,10 @@ const ( BoolExprTypeAnd BoolExprTypeOr BoolExprTypeNot + + // Added for MySQL + BoolExprTypeIsNull + BoolExprTypeIsNotNull ) type BoolExprType uint