diff --git a/prql-compiler/src/sql/gen_expr.rs b/prql-compiler/src/sql/gen_expr.rs index 0f4961c1f7a3..5cc21bceca68 100644 --- a/prql-compiler/src/sql/gen_expr.rs +++ b/prql-compiler/src/sql/gen_expr.rs @@ -22,7 +22,6 @@ use crate::sql::context::ColumnDecl; use crate::utils::OrMap; use super::gen_projection::try_into_exprs; -use super::std::*; use super::Context; pub(super) fn translate_expr(expr: Expr, ctx: &mut Context) -> Result { @@ -76,140 +75,149 @@ pub(super) fn translate_expr(expr: Expr, ctx: &mut Context) -> Result { - let expr = match try_into_is_null(expr, ctx)? { - Ok(is_null) => return Ok(is_null), - Err(expr) => expr, - }; - - let expr = match try_into_between(expr, ctx)? { - Ok(between) => return Ok(between), - Err(expr) => expr, - }; - - let expr = match try_into_concat_function(expr, ctx)? { - Ok(between) => return Ok(between), - Err(expr) => expr, - }; - - let expr = match try_into_regex_function(expr, ctx)? { - Ok(between) => return Ok(between), - Err(expr) => expr, - }; - - let expr = match try_into_binary_op(expr, ctx)? { - Ok(bin_op) => return Ok(bin_op), - Err(expr) => expr, - }; - - let expr = match try_into_unary_op(expr, ctx)? { - Ok(un_op) => return Ok(un_op), - Err(expr) => expr, - }; - + ExprKind::BuiltInFunction { ref name, ref args } => { + // A few special cases and then fall-through to the standard approach. + match name.as_str() { + // See notes in `std.rs` re whether we use names vs. + // `FunctionDecl` vs. an Enum; and getting the correct + // number of args from there. Currently the error messages + // for the wrong number of args will be bad (though it's an + // unusual case where RQ contains something like `std.eq` + // with the wrong number of args). + "std.eq" | "std.ne" => { + if let [a, b] = args.as_slice() { + if a.kind == ExprKind::Literal(Literal::Null) + || b.kind == ExprKind::Literal(Literal::Null) + { + return process_null(name, args, ctx); + } else if let Some(op) = operator_from_name(name) { + return translate_binary_operator(a, b, op, ctx); + } + } + } + "std.neg" | "std.not" => { + if let [arg] = args.as_slice() { + return process_unary(name, arg, ctx); + } + } + "std.concat" => return process_concat(&expr, ctx), + "std.regex_search" => { + if let [search, target] = args.as_slice() { + return process_regex(search, target, ctx); + } + } + _ => match try_into_between(expr.clone(), ctx)? { + Some(between_expr) => return Ok(between_expr), + None => { + if let Some(op) = operator_from_name(name) { + if let [left, right] = args.as_slice() { + return translate_binary_operator(left, right, op, ctx); + } + } + } + }, + } super::std::translate_built_in(expr, ctx)? } }) } -fn try_into_binary_op(expr: Expr, ctx: &mut Context) -> Result> { - use BinaryOperator::*; - const DECLS: [super::std::FunctionDecl<2>; 14] = [ - STD_MUL, STD_DIV, STD_MOD, STD_ADD, STD_SUB, STD_EQ, STD_NE, STD_GT, STD_LT, STD_GTE, - STD_LTE, STD_AND, STD_OR, STD_CONCAT, - ]; - const OPS: [BinaryOperator; 14] = [ - Multiply, - Divide, - Modulo, - Plus, - Minus, - Eq, - NotEq, - Gt, - Lt, - GtEq, - LtEq, - And, - Or, - StringConcat, - ]; - - let Some((decl, _)) = try_unpack(&expr, DECLS)? else { - return Ok(Err(expr)); - }; - - // this lookup is O(N), but 13 is not that big of a N - let decl_index = DECLS.iter().position(|x| x == &decl).unwrap(); - let op = OPS[decl_index].clone(); - let [left, right] = unpack(expr, decl); - - let strength = op.binding_strength(); - let left = translate_operand(left, strength, !op.associates_left(), ctx)?; - let right = translate_operand(right, strength, !op.associates_right(), ctx)?; - Ok(Ok(sql_ast::Expr::BinaryOp { left, right, op })) -} - -fn try_into_unary_op(expr: Expr, ctx: &mut Context) -> Result> { - use UnaryOperator::*; - const DECLS: [super::std::FunctionDecl<1>; 2] = [STD_NEG, STD_NOT]; - const OPS: [UnaryOperator; 2] = [Minus, Not]; - - let Some((decl, _)) = try_unpack(&expr, DECLS)? else { - return Ok(Err(expr)); +/// Translates into IS NULL if possible +fn process_null(name: &str, args: &[Expr], ctx: &mut Context) -> Result { + let (a, b) = (&args[0], &args[1]); + let operand = if matches!(a.kind, ExprKind::Literal(Literal::Null)) { + b + } else { + a }; - // this lookup is O(N), but 13 is not that big of a N - let decl_index = DECLS.iter().position(|x| x == &decl).unwrap(); - let op = OPS[decl_index]; - let [arg] = unpack(expr, decl); - let expr = translate_operand(arg, op.binding_strength(), false, ctx)?; - Ok(Ok(sql_ast::Expr::UnaryOp { op, expr })) + // If this were an Enum, we could match on it (see notes in `std.rs`). + if name == "std.eq" { + let strength = + sql_ast::Expr::IsNull(Box::new(sql_ast::Expr::Value(Value::Null))).binding_strength(); + let expr = translate_operand(operand.clone(), strength, false, ctx)?; + Ok(sql_ast::Expr::IsNull(expr)) + } else if name == "std.ne" { + let strength = sql_ast::Expr::IsNotNull(Box::new(sql_ast::Expr::Value(Value::Null))) + .binding_strength(); + let expr = translate_operand(operand.clone(), strength, false, ctx)?; + Ok(sql_ast::Expr::IsNotNull(expr)) + } else { + unreachable!() + } } -fn try_into_concat_function(expr: Expr, ctx: &mut Context) -> Result> { - if !ctx.dialect.has_concat_function() { - return Ok(Err(expr)); +fn process_unary(name: &str, arg: &Expr, ctx: &mut Context) -> Result { + match name { + "std.neg" => { + let expr = translate_operand( + arg.clone(), + UnaryOperator::Minus.binding_strength(), + false, + ctx, + )?; + Ok(sql_ast::Expr::UnaryOp { + op: UnaryOperator::Minus, + expr, + }) + } + "std.not" => { + let expr = translate_operand( + arg.clone(), + UnaryOperator::Not.binding_strength(), + false, + ctx, + )?; + Ok(sql_ast::Expr::UnaryOp { + op: UnaryOperator::Not, + expr, + }) + } + _ => unreachable!(), // We've already covered all cases above } - - let args = match try_unpack_concat(expr)? { - Ok(args) => args, - Err(expr) => return Ok(Err(expr)), - }; - - let args = args - .into_iter() - .map(|a| { - translate_expr(a, ctx) - .map(FunctionArgExpr::Expr) - .map(FunctionArg::Unnamed) - }) - .try_collect()?; - - Ok(Ok(sql_ast::Expr::Function(Function { - name: ObjectName(vec![sql_ast::Ident::new("CONCAT")]), - args, - over: None, - distinct: false, - special: false, - }))) } -fn try_into_regex_function(expr: Expr, ctx: &mut Context) -> Result> { - // This function is mostly copied from the other `try_into_*` functions — - // don't use this as a template. - // - // Possibly we might be able to simplify some of this, even if it's - // more verbose / less performant? It's not easy rust to add a simple - // function. But possibly we keep it complicated here and allow for more - // implementations in PRQL std lib. - - const DECLS: [super::std::FunctionDecl<2>; 1] = [STD_REGEX_SEARCH]; +fn process_concat(expr: &Expr, ctx: &mut Context) -> Result { + if ctx.dialect.has_concat_function() { + let concat_args = collect_concat_args(expr); + + let args = concat_args + .iter() + .map(|a| { + translate_expr((*a).clone(), ctx) + .map(FunctionArgExpr::Expr) + .map(FunctionArg::Unnamed) + }) + .try_collect()?; + + Ok(sql_ast::Expr::Function(Function { + name: ObjectName(vec![sql_ast::Ident::new("CONCAT")]), + args, + over: None, + distinct: false, + special: false, + })) + } else { + let concat_args = collect_concat_args(expr); + + let mut iter = concat_args.into_iter(); + let first_expr = iter.next().unwrap(); + let mut current_expr = translate_expr(first_expr.clone(), ctx)?; + + for arg in iter { + let translated_arg = translate_expr(arg.clone(), ctx)?; + current_expr = sql_ast::Expr::BinaryOp { + left: Box::new(current_expr), + op: BinaryOperator::StringConcat, + right: Box::new(translated_arg), + }; + } - let Some((decl, _)) = try_unpack(&expr, DECLS)? else { - return Ok(Err(expr)); - }; + Ok(current_expr) + } +} +fn process_regex(search: &Expr, target: &Expr, ctx: &mut Context) -> Result { let Some(regex_function) = ctx.dialect.regex_function() else { // TODO: name the dialect, but not immediately obvious how to actually // get the dialect string from a `DialectHandler`. @@ -218,38 +226,102 @@ fn try_into_regex_function(expr: Expr, ctx: &mut Context) -> Result Result, Expr>> { - let Some((_, _)) = try_unpack(&expr, [STD_CONCAT])? else { - return Ok(Err(expr)); - }; - let [left, right] = unpack(expr, STD_CONCAT); +fn translate_binary_operator( + left: &Expr, + right: &Expr, + op: BinaryOperator, + ctx: &mut Context, +) -> Result { + let strength = op.binding_strength(); + let left = translate_operand(left.clone(), strength, !op.associates_left(), ctx)?; + let right = translate_operand(right.clone(), strength, !op.associates_right(), ctx)?; - let mut args = match try_unpack_concat(left)? { - Ok(args) => args, - Err(left) => vec![left], - }; - args.push(right); - Ok(Ok(args)) + Ok(sql_ast::Expr::BinaryOp { left, op, right }) +} + +fn collect_concat_args(expr: &Expr) -> Vec<&Expr> { + match &expr.kind { + ExprKind::BuiltInFunction { name, args } if name == "std.concat" => { + args.iter().flat_map(collect_concat_args).collect() + } + _ => vec![expr], + } +} + +/// Translate expr into a BETWEEN statement if possible, otherwise returns the expr unchanged. +fn try_into_between(expr: Expr, ctx: &mut Context) -> Result, anyhow::Error> { + if let ExprKind::BuiltInFunction { name, args } = &expr.kind { + if name == "std.and" { + if let [a, b] = args.as_slice() { + if let ( + ExprKind::BuiltInFunction { + name: a_name, + args: a_args, + }, + ExprKind::BuiltInFunction { + name: b_name, + args: b_args, + }, + ) = (&a.kind, &b.kind) + { + if a_name == "std.gte" && b_name == "std.lte" { + if let ([a_l, a_r], [b_l, b_r]) = (a_args.as_slice(), b_args.as_slice()) { + // We need for the values on each arm to be the same; e.g. x + // > 3 and x < 5 + if a_l == b_l { + return Ok(Some(sql_ast::Expr::Between { + expr: translate_operand(a_l.clone(), 0, false, ctx)?, + negated: false, + low: translate_operand(a_r.clone(), 0, false, ctx)?, + high: translate_operand(b_r.clone(), 0, false, ctx)?, + })); + } + } + } + } + } + } + } + Ok(None) +} + +fn operator_from_name(name: &str) -> Option { + use BinaryOperator::*; + match name { + "std.mul" => Some(Multiply), + "std.div" => Some(Divide), + "std.mod" => Some(Modulo), + "std.add" => Some(Plus), + "std.sub" => Some(Minus), + "std.eq" => Some(Eq), + "std.ne" => Some(NotEq), + "std.gt" => Some(Gt), + "std.lt" => Some(Lt), + "std.gte" => Some(GtEq), + "std.lte" => Some(LtEq), + "std.and" => Some(And), + "std.or" => Some(Or), + "std.concat" => Some(StringConcat), + _ => None, + } } pub(super) fn translate_literal(l: Literal, ctx: &Context) -> Result { @@ -575,74 +647,6 @@ pub(super) fn translate_select_item(cid: CId, ctx: &mut Context) -> Result Result> { - let Some((decl, [a, b])) = try_unpack(&expr, [STD_EQ, STD_NE])? else { - return Ok(Err(expr)) - }; - - let take_a = if matches!(a.kind, ExprKind::Literal(Literal::Null)) { - false - } else if matches!(b.kind, ExprKind::Literal(Literal::Null)) { - true - } else { - return Ok(Err(expr)); - }; - let is_std_eq = decl == STD_EQ; - - // we are sure this translates to IS NULL - let [a, b] = unpack(expr, decl); - let operand = if take_a { a } else { b }; - - let strength = - sql_ast::Expr::IsNull(Box::new(sql_ast::Expr::Value(Value::Null))).binding_strength(); - let expr = translate_operand(operand, strength, false, ctx)?; - - Ok(Ok(if is_std_eq { - sql_ast::Expr::IsNull(expr) - } else { - sql_ast::Expr::IsNotNull(expr) - })) -} - -/// Translate expr into a BETWEEN statement if possible, otherwise returns the expr unchanged. -/// -/// Outer Result contains an error, inner Result contains the unmatched expr. -fn try_into_between(expr: Expr, ctx: &mut Context) -> Result> { - // validate that this expr matches the criteria - - let Some((_, [a, b])) = try_unpack(&expr, [STD_AND])? else { - return Ok(Err(expr)); - }; - - let Some((_, [a_l, _a_r])) = try_unpack(a, [STD_GTE])? else { - return Ok(Err(expr)); - }; - let Some((_, [b_l, _b_r])) = try_unpack(b, [STD_LTE])? else { - return Ok(Err(expr)); - }; - - if a_l != b_l { - return Ok(Err(expr)); - } - - // at this point we are sure that this should translate to Between - // so the expr can be unpacked for good - - let [a, b] = unpack(expr, STD_AND); - let [a_l, a_r] = unpack(a, STD_GTE); - let [_, b_r] = unpack(b, STD_LTE); - - Ok(Ok(sql_ast::Expr::Between { - expr: translate_operand(a_l, 0, false, ctx)?, - negated: false, - low: translate_operand(a_r, 0, false, ctx)?, - high: translate_operand(b_r, 0, false, ctx)?, - })) -} - fn translate_windowed( expr: sql_ast::Expr, window: Window, diff --git a/prql-compiler/src/sql/preprocess.rs b/prql-compiler/src/sql/preprocess.rs index 8383beb90a2c..7b1786b379c8 100644 --- a/prql-compiler/src/sql/preprocess.rs +++ b/prql-compiler/src/sql/preprocess.rs @@ -436,17 +436,25 @@ fn collect_equals(expr: &Expr) -> Result<(Vec<&Expr>, Vec<&Expr>)> { let mut lefts = Vec::new(); let mut rights = Vec::new(); - if let Some((_, [left, right])) = super::std::try_unpack(expr, [super::std::STD_EQ])? { - lefts.push(left); - rights.push(right); - } else if let Some((_, [left, right])) = super::std::try_unpack(expr, [super::std::STD_AND])? { - let (l, r) = collect_equals(left)?; - lefts.extend(l); - rights.extend(r); - - let (l, r) = collect_equals(right)?; - lefts.extend(l); - rights.extend(r); + match &expr.kind { + ExprKind::BuiltInFunction { name, args } + if name == super::std::STD_EQ.name && args.len() == 2 => + { + lefts.push(&args[0]); + rights.push(&args[1]); + } + ExprKind::BuiltInFunction { name, args } + if name == super::std::STD_AND.name && args.len() == 2 => + { + let (l, r) = collect_equals(&args[0])?; + lefts.extend(l); + rights.extend(r); + + let (l, r) = collect_equals(&args[1])?; + lefts.extend(l); + rights.extend(r); + } + _ => (), } Ok((lefts, rights)) @@ -514,19 +522,26 @@ impl RqFold for Normalizer { ..expr }; - let Some((decl, _)) = super::std::try_unpack(&expr, [super::std::STD_EQ])? else { - return Ok(expr); - }; - let name = decl.name.to_string(); - let span = expr.span; - let [left, right] = super::std::unpack(expr, decl); + if let ExprKind::BuiltInFunction { name, args } = &expr.kind { + if name == "std.eq" && args.len() == 2 { + let (left, right) = (&args[0], &args[1]); + let span = expr.span; + let new_args = if let ExprKind::Literal(Literal::Null) = &left.kind { + vec![right.clone(), left.clone()] + } else { + vec![left.clone(), right.clone()] + }; + let new_kind = ExprKind::BuiltInFunction { + name: name.clone(), + args: new_args, + }; + return Ok(Expr { + kind: new_kind, + span, + }); + } + } - let args = if let ExprKind::Literal(Literal::Null) = &left.kind { - vec![right, left] - } else { - vec![left, right] - }; - let kind = ExprKind::BuiltInFunction { name, args }; - Ok(Expr { kind, span }) + Ok(expr) } } diff --git a/prql-compiler/src/sql/std.rs b/prql-compiler/src/sql/std.rs index cbf059aae327..b08ef02e1c2f 100644 --- a/prql-compiler/src/sql/std.rs +++ b/prql-compiler/src/sql/std.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::iter::zip; use anyhow::Result; -use itertools::Itertools; use once_cell::sync::Lazy; use sqlparser::ast::{self as sql_ast}; @@ -80,65 +79,28 @@ impl FunctionDecl { } } +// TODO: We're not using many of these, and instead matching on the name now. +// Some options: +// - Go back to matching on the defined `FunctionDecl`s, uncomment these +// - Make these into an Enum — would make some matching simpler +// - Separate the operators out into an Enum structure (and possibly the binary +// from the unary ones?) + // TODO: automatically generate these definitions from std_impl.prql -pub(crate) const STD_MUL: FunctionDecl<2> = FunctionDecl::new("std.mul"); -pub(crate) const STD_DIV: FunctionDecl<2> = FunctionDecl::new("std.div"); -pub(crate) const STD_MOD: FunctionDecl<2> = FunctionDecl::new("std.mod"); -pub(crate) const STD_ADD: FunctionDecl<2> = FunctionDecl::new("std.add"); -pub(crate) const STD_SUB: FunctionDecl<2> = FunctionDecl::new("std.sub"); +// pub(crate) const STD_MUL: FunctionDecl<2> = FunctionDecl::new("std.mul"); +// pub(crate) const STD_DIV: FunctionDecl<2> = FunctionDecl::new("std.div"); +// pub(crate) const STD_MOD: FunctionDecl<2> = FunctionDecl::new("std.mod"); +// pub(crate) const STD_ADD: FunctionDecl<2> = FunctionDecl::new("std.add"); +// pub(crate) const STD_SUB: FunctionDecl<2> = FunctionDecl::new("std.sub"); pub(crate) const STD_EQ: FunctionDecl<2> = FunctionDecl::new("std.eq"); -pub(crate) const STD_NE: FunctionDecl<2> = FunctionDecl::new("std.ne"); -pub(crate) const STD_GT: FunctionDecl<2> = FunctionDecl::new("std.gt"); -pub(crate) const STD_LT: FunctionDecl<2> = FunctionDecl::new("std.lt"); +// pub(crate) const STD_NE: FunctionDecl<2> = FunctionDecl::new("std.ne"); +// pub(crate) const STD_GT: FunctionDecl<2> = FunctionDecl::new("std.gt"); +// pub(crate) const STD_LT: FunctionDecl<2> = FunctionDecl::new("std.lt"); pub(crate) const STD_GTE: FunctionDecl<2> = FunctionDecl::new("std.gte"); pub(crate) const STD_LTE: FunctionDecl<2> = FunctionDecl::new("std.lte"); -pub(crate) const STD_REGEX_SEARCH: FunctionDecl<2> = FunctionDecl::new("std.regex_search"); +// pub(crate) const STD_REGEX_SEARCH: FunctionDecl<2> = FunctionDecl::new("std.regex_search"); pub(crate) const STD_AND: FunctionDecl<2> = FunctionDecl::new("std.and"); -pub(crate) const STD_OR: FunctionDecl<2> = FunctionDecl::new("std.or"); +// pub(crate) const STD_OR: FunctionDecl<2> = FunctionDecl::new("std.or"); pub(crate) const STD_CONCAT: FunctionDecl<2> = FunctionDecl::new("std.concat"); -pub(crate) const STD_NEG: FunctionDecl<1> = FunctionDecl::new("std.neg"); -pub(crate) const STD_NOT: FunctionDecl<1> = FunctionDecl::new("std.not"); - -/// Assumes the expr is: -/// - [rq::ExprKind::BuiltInFunction], -/// - name matches the function decl, -/// - number of arguments matches the function decl. -/// Returns the unpacked arguments. Panics if any of the assumptions are not met. -/// -/// This function should probably be called after the expr was validated with [try_unpack_ref]. -pub(super) fn unpack( - expr: rq::Expr, - decl: FunctionDecl, -) -> [rq::Expr; ARG_COUNT] { - if let rq::ExprKind::BuiltInFunction { name, args } = expr.kind { - if decl.name == name { - return args.try_into().unwrap(); - } - } - unreachable!() -} - -/// Checks that the expr matches the passed built-in-function. -/// Returns an error if the matched function has wrong number of arguments. This can happen when -/// passing an invalid RQ representation. -pub(super) fn try_unpack( - expr: &rq::Expr, - decls: [FunctionDecl; X], -) -> Result, [&rq::Expr; ARG_COUNT])>> { - if let rq::ExprKind::BuiltInFunction { name, args } = &expr.kind { - for decl in decls { - if decl.name != name { - continue; - }; - - let args: [&rq::Expr; ARG_COUNT] = args - .iter() - .collect_vec() - .try_into() - .map_err(|_| anyhow::anyhow!("Bad usage of function {}", decl.name))?; - - return Ok(Some((decl, args))); - } - } - Ok(None) -} +// pub(crate) const STD_NEG: FunctionDecl<1> = FunctionDecl::new("std.neg"); +// pub(crate) const STD_NOT: FunctionDecl<1> = FunctionDecl::new("std.not"); diff --git a/prql-compiler/src/tests/test.rs b/prql-compiler/src/tests/test.rs index f4d2ff210884..c1ae740200d8 100644 --- a/prql-compiler/src/tests/test.rs +++ b/prql-compiler/src/tests/test.rs @@ -726,16 +726,14 @@ fn test_numbers() { #[test] fn test_ranges() { - let query = r###" + assert_display_snapshot!((compile(r###" from employees derive [ close = (distance | in 0..100), far = (distance | in 100..), country_founding | in @1776-07-04..@1787-09-17 ] - "###; - - assert_display_snapshot!((compile(query).unwrap()), @r###" + "###).unwrap()), @r###" SELECT *, distance BETWEEN 0 AND 100 AS close,