diff --git a/.gitignore b/.gitignore index fbb9182..1985ac9 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,6 @@ datasets/ # Playwright MCP scratch output (screenshots, console logs, downloads) .playwright-mcp/ + +# Full-run log from `sqlbench regen` (tee'd locally, not an artifact) +regen.log diff --git a/Cargo.toml b/Cargo.toml index f813457..5943d3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,18 @@ harness = false name = "sqlbench" path = "src/bin/sqlbench.rs" +[[bin]] +name = "build_sqlite_suite" +path = "src/bin/build_sqlite_suite.rs" + +[[bin]] +name = "build_proc_suites" +path = "src/bin/build_proc_suites.rs" + +[[bin]] +name = "repair_corpus" +path = "src/bin/repair_corpus.rs" + # Strip only DWARF debug info from release builds. The WASM viewer is built with # `dx build --web --release`; rustc's DWARF tripped wasm-opt ("unsupported # version of DWARF"), so removing it lets wasm-opt succeed and shrinks the wasm. diff --git a/README.md b/README.md index d9d9fa4..1c3a10f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Choosing a SQL parser for a Rust project means weighing dialect coverage, correc We evaluated nine parser libraries: [sqlparser-rs](https://github.com/sqlparser-rs/sqlparser-rs) (Apache DataFusion), [pg_query.rs](https://github.com/pganalyze/pg_query.rs) and its faster summary mode (Rust bindings to [libpg_query](https://github.com/pganalyze/libpg_query), PostgreSQL's own parser), [databend-common-ast](https://crates.io/crates/databend-common-ast), [polyglot-sql](https://github.com/tobilg/polyglot), [sqlglot-rust](https://crates.io/crates/sqlglot-rust), [qusql-parse](https://crates.io/crates/qusql-parse), [sqlite3-parser](https://crates.io/crates/sqlite3-parser) (lemon-rs), and [turso_parser](https://crates.io/crates/turso_parser) (the SQLite parser from Turso), plus [orql](https://codeberg.org/xitep/orql) on Oracle. 
We ran them against a corpus of 340,938 statements spanning 13 dialects, drawn from each engine's own regression suites and official samples and committed compressed so every run is reproducible. -We exercised each parser in the dialect that matches the corpus under test. Where a dialect has a runnable engine, we labelled each statement valid or invalid with the real database engine itself, run in Docker via [testcontainers](https://github.com/testcontainers/testcontainers-rs): a statement counts as valid unless the engine reports a syntax error, so a missing table or column still counts as parsed. Against that ground truth we scored the parsers on recall (valid statements accepted), false positives (invalid statements wrongly accepted), display round-trip stability, and canonical-form fidelity. The other dialects have no runnable engine, so their statements count as provenance-valid and the metric is simply the acceptance rate. Across all dialects, we captured speed as a per-statement parse-time distribution over every accepted statement, and memory as the peak and retained bytes per statement under a counting allocator. A batch axis additionally parses each parser's whole accepted set as a single script, showing what bulk parsing amortizes, and a time machine benchmarks the historical releases of every pure-Rust parser (59 versions in total, including every sqlparser-rs minor since January 2023), so each parser page also charts how coverage, speed, and memory evolved across releases. +We exercised each parser in the dialect that matches the corpus under test. Where a dialect has a runnable engine, we labelled each statement valid or invalid with the real database engine itself, run in Docker via [testcontainers](https://github.com/testcontainers/testcontainers-rs): a statement counts as valid unless the engine reports a syntax error, so a missing table or column still counts as parsed. 
Against that ground truth we scored the parsers on recall (valid statements accepted), false positives (invalid statements wrongly accepted), and display round-trip stability. The other dialects have no runnable engine, so their statements count as provenance-valid and the metric is simply the acceptance rate. Across all dialects, we captured speed as a per-statement parse-time distribution over every accepted statement, and memory as the peak and retained bytes per statement under a counting allocator. A batch axis additionally parses each parser's whole accepted set as a single script, showing what bulk parsing amortizes, and a time machine benchmarks the historical releases of every pure-Rust parser (59 versions in total, including every sqlparser-rs minor since January 2023), so each parser page also charts how coverage, speed, and memory evolved across releases. On their home dialect the reference bindings are exact by construction, so the more telling comparison is among the pure-Rust parsers. There, [sqlparser-rs](https://github.com/sqlparser-rs/sqlparser-rs) is the most broadly capable, the permissive parsers such as [polyglot-sql](https://github.com/tobilg/polyglot) accept the most statements but pay for it with a high false-positive rate, and the stricter parsers reject more in exchange for precision. Speed spans more than an order of magnitude, from well under a microsecond per statement for the fastest parsers to the low single-digit microseconds for most, with [polyglot-sql](https://github.com/tobilg/polyglot) a clear outlier at roughly fifteen. No parser leads on every axis, so the right choice comes down to what a given project values most: broad coverage, few false positives, or raw speed. 
@@ -38,7 +38,7 @@ Per-parser repository metadata (stars, contributors, fuzzing, test and benchmark 340,938 statements across 32 files and 13 dialects, committed compressed as `datasets.tar.zst` (5.6 MB) and unpacked to `datasets/{dialect}/{name}.txt`, one statement per line. The commands below extract it automatically on first use. All sources are openly licensed (Apache-2.0, MIT, BSD, public domain or CC-BY), drawn from each engine's own regression suites and official samples. The SQLite corpus includes the SQLite project's own official test suite (public domain), which exercises SQLite-specific grammar such as PRAGMAs, virtual tables, recursive CTEs, and upsert. Natural-language-with-embedded-SQL datasets are intentionally excluded. -Correctness is defined per dialect. Dialects with a runnable engine are graded against that real database engine, run in Docker via testcontainers by the `oracle` crate: a statement is valid unless the engine reports a syntax error (a missing table or column still counts as parsed). The validity labels are computed once and committed under `oracle/labels`, so grading and CI need no Docker. That reference splits the corpus into valid and invalid and scores recall, false positives, round-trip, and fidelity. Dialects with no runnable engine (cloud services, heavy JVM engines) have no reference, so their statements count as provenance-valid (sourced from each engine's own suites) and the metric is acceptance rate. Speed is a per-statement parse-time distribution over every accepted statement, timed with an adaptive iteration count on a no-`catch_unwind` path. Memory is measured separately with a counting allocator, as peak live bytes and retained (AST) bytes per statement. A companion batch axis parses each parser's whole accepted set as one script and normalizes the time and memory by the statement count, showing what bulk parsing amortizes against parsing one statement at a time. 
A batch that does not parse the whole set (a parser that bails out partway) is dropped rather than reported, and parsers without a multi-statement entry point (databend-common-ast) sit out the batch axis. +Correctness is defined per dialect. Dialects with a runnable engine are graded against that real database engine, run in Docker via testcontainers by the `oracle` crate: a statement is valid unless the engine reports a syntax error (a missing table or column still counts as parsed). The validity labels are computed once and committed under `oracle/labels`, so grading and CI need no Docker. That reference splits the corpus into valid and invalid and scores recall, false positives, and round-trip. Dialects with no runnable engine (cloud services, heavy JVM engines) have no reference, so their statements count as provenance-valid (sourced from each engine's own suites) and the metric is acceptance rate. Speed is a per-statement parse-time distribution over every accepted statement, timed with an adaptive iteration count on a no-`catch_unwind` path. Memory is measured separately with a counting allocator, as peak live bytes and retained (AST) bytes per statement. A companion batch axis parses each parser's whole accepted set as one script and normalizes the time and memory by the statement count, showing what bulk parsing amortizes against parsing one statement at a time. A batch that does not parse the whole set (a parser that bails out partway) is dropped rather than reported, and parsers without a multi-statement entry point (databend-common-ast) sit out the batch axis. ## Running diff --git a/benches/batch_parsing.rs b/benches/batch_parsing.rs index 68c6650..51a5383 100644 --- a/benches/batch_parsing.rs +++ b/benches/batch_parsing.rs @@ -1,45 +1,39 @@ -//! Multi-dialect BATCH (whole-script) parse-time benchmark over the full +//! Multi-dialect BATCH (multi-statement script) parse benchmark over the full //! `datasets/` corpus. //! //! 
Companion to `benches/parsing.rs`. Where `parsing` times each statement in -//! isolation, this concatenates every statement a parser accepts in a dialect -//! into one script and times parsing that whole script in a single call, then -//! divides by the statement count to get a normalized per-statement cost. The -//! contrast between this and the per-statement median isolates what a batch API -//! pays or amortizes, the effect raised in issue #15: `Parser::parse_sql` grows -//! a `Vec` of large `Statement` values, so bulk parsing can behave differently -//! from many single-statement calls. +//! isolation, this draws random fixed-size batches of statements a parser can +//! individually digest, joins each into one script, and parses it in a single +//! call. It reports two things per (parser, dialect): batch accuracy, the share +//! of batches that reparse to exactly the expected statement count, and the +//! per-statement parse time averaged over the batches that did. Sampling instead +//! of concatenating the whole accepted set keeps one statement that mishandles +//! the terminator (a real but narrow bug) from voiding the entire measurement +//! under the all-or-nothing `parse_sql`. //! -//! Both axes are measured over the SAME accepted set (statements the parser -//! parses in that dialect), so the two numbers are directly comparable. +//! The sampling, joining, and accuracy live in `sql_ast_benchmark::batch` so the +//! memory bench (`membench -- batch`) and the time machine sample identically. +//! Only parsers with a multi-statement entry point take part (`can_batch`). //! -//! Only parsers with a multi-statement entry point take part (see -//! `BenchParser::can_batch`). `databend-common-ast` parses one statement per -//! call and is simply skipped here. -//! -//! Output (under `target/batch_dist/`), self-contained for now (not yet wired -//! into the web export): -//! - `summary.csv` : per-pair statement count, statements the parser saw, -//! 
batch size in bytes, whole-script time, and time normalized per -//! statement. +//! Output (`target/batch_dist/summary.csv`): per pair the eligible count, the +//! number of batches, how many were correct, the accuracy percent, and the +//! per-statement time over correct batches. //! //! Full run: `cargo bench --bench batch_parsing` //! Smoke (default): `cargo test` or `cargo bench --bench batch_parsing -- --test` -//! -//! The full run unpacks `datasets.tar.zst` automatically if `datasets/` is -//! missing. The smoke path needs no corpus, so `cargo test` stays fast. -use sql_ast_benchmark::batch::join_batch; +use sql_ast_benchmark::batch::{evaluate_batches, reports_statement_count, BATCH_K, BATCH_M}; use sql_ast_benchmark::datasets::Dialect; use sql_ast_benchmark::report::load_dialect; use sql_ast_benchmark::BenchParser; use std::fs; use std::hint::black_box; use std::io::Write as _; +use std::panic::AssertUnwindSafe; use std::time::Instant; /// Deep statements can exhaust the default stack inside recursive-descent -/// parsers, and a stack overflow aborts the process, so time on a large stack. +/// parsers, and a stack overflow aborts the process, so run on a large stack. const WORKER_STACK: usize = 1024 * 1024 * 1024; const OUT_DIR: &str = "target/batch_dist"; @@ -60,14 +54,13 @@ const DIALECTS: &[Dialect] = &[ Dialect::Multi, ]; -/// Whole-script parse time (ns/batch): adaptive iteration count so a short -/// script still accumulates enough work per round, capped low because one batch -/// call already does a lot. Best (min) of `ROUNDS` rounds. -fn time_batch(mut f: impl FnMut() -> usize) -> f64 { - const TARGET_NS: u128 = 2_000_000; // aim for ~2 ms of work per round +/// Whole-sweep parse time (ns): adaptive iteration count so a short sweep still +/// accumulates enough work per round, best (min) of `ROUNDS` rounds. 
+fn time_sweep(mut f: impl FnMut() -> usize) -> f64 { + const TARGET_NS: u128 = 2_000_000; const ROUNDS: usize = 5; - black_box(f()); // warm up + black_box(f()); let probe = Instant::now(); black_box(f()); let single = probe.elapsed().as_nanos().max(1); @@ -85,56 +78,74 @@ fn time_batch(mut f: impl FnMut() -> usize) -> f64 { best } +/// Parse one script to a statement count, treating a caught panic as 0 so a +/// single pathological input does not abort the whole (parser, dialect) pair. +fn safe_count(parser: BenchParser, sql: &str, dialect: Dialect) -> usize { + std::panic::catch_unwind(AssertUnwindSafe(|| { + parser.parse_batch(sql, dialect).unwrap_or(0) + })) + .unwrap_or(0) +} + struct Row { dialect: &'static str, parser: &'static str, - /// Statements fed into the batch (the parser's accepted set). - n_accepted: usize, - /// Statements the parser reported parsing from the batch (coverage). - n_parsed: usize, - batch_bytes: usize, - /// Whole-script parse time (ns). - batch_ns: f64, - /// `batch_ns / n_accepted`: time per statement in batch context. - ns_per_stmt: f64, + n_eligible: usize, + k: usize, + n_correct: usize, + accuracy_pct: Option, + /// Per-statement parse time over the correct batches (ns), `None` when none. + ns_per_stmt: Option, } -/// Time one (parser, dialect) pair: build the accepted set, concatenate it into -/// one script, time the whole-script parse, and normalize per statement. +/// Evaluate one (parser, dialect) pair: build the eligible set, sample batches, +/// measure accuracy, and time the batches that parsed correctly. fn run_pair(parser: BenchParser, dialect: Dialect, stmts: &[String]) -> Row { - let accepted: Vec<&str> = stmts + // Eligible = accepted, parses to exactly one statement alone, and safe to + // batch (not COPY ... FROM STDIN). The single==1 check makes the expected + // per-batch count exactly the batch size. 
+ let eligible: Vec<&str> = stmts .iter() - .filter(|s| parser.accepts(s, dialect) == Some(true)) + .filter(|s| { + parser.accepts(s, dialect) == Some(true) + && sql_ast_benchmark::batch::batch_eligible(s) + && safe_count(parser, s, dialect) == 1 + }) .map(String::as_str) .collect(); - let mut row = Row { + let label = format!("{}/{}", dialect.dir_name(), parser.name()); + let eval = evaluate_batches(&eligible, &label, |s| safe_count(parser, s, dialect)); + + let ns_per_stmt = if eval.n_correct == 0 { + None + } else { + let denom = (eval.n_correct * eval.effective_m) as f64; + let sweep = time_sweep(|| { + eval.correct_scripts + .iter() + .map(|s| safe_count(parser, s, dialect)) + .sum() + }); + Some(sweep / denom) + }; + + Row { dialect: dialect.dir_name(), parser: parser.name(), - n_accepted: accepted.len(), - n_parsed: 0, - batch_bytes: 0, - batch_ns: 0.0, - ns_per_stmt: 0.0, - }; - if accepted.is_empty() { - return row; + n_eligible: eval.n_eligible, + k: eval.k, + n_correct: eval.n_correct, + accuracy_pct: eval.accuracy_pct(), + ns_per_stmt, } - - let batch = join_batch(&accepted); - row.batch_bytes = batch.len(); - row.n_parsed = parser.parse_batch(&batch, dialect).unwrap_or(0); - row.batch_ns = time_batch(|| parser.parse_batch(&batch, dialect).unwrap_or(0)); - row.ns_per_stmt = row.batch_ns / accepted.len() as f64; - row } /// Quick smoke check used by `cargo test`: every batch-capable parser parses a -/// tiny multi-statement script per supported dialect without panicking. Needs -/// no corpus, so it stays instant. +/// tiny multi-statement script per supported dialect without panicking. 
fn smoke() { std::panic::set_hook(Box::new(|_| {})); - let script = "SELECT 1;\nSELECT 2;\nSELECT 3"; + let script = "SELECT 1\n;\nSELECT 2\n;\nSELECT 3"; for &dialect in DIALECTS { for parser in BenchParser::all() { if !parser.can_batch() || !parser.supports(dialect) { @@ -147,9 +158,6 @@ fn smoke() { } fn main() { - // Match `benches/parsing.rs`: only an explicit `cargo bench` (which passes - // `--bench` and not `--test`) does the full, datasets-backed run. `cargo - // test` and a bare run take the fast smoke path, which needs no corpus. let args: Vec = std::env::args().collect(); let full_run = args.iter().any(|a| a == "--bench") && !args.iter().any(|a| a == "--test"); if !full_run { @@ -157,8 +165,6 @@ fn main() { return; } - // Acceptance checks are panic-guarded. Suppress the default panic message so - // a caught panic does not spam stderr. std::panic::set_hook(Box::new(|_| {})); if let Err(e) = sql_ast_benchmark::datasets::ensure_corpus() { @@ -170,12 +176,13 @@ fn main() { let mut summary = fs::File::create(format!("{OUT_DIR}/summary.csv")).expect("summary.csv"); writeln!( summary, - "dialect,parser,n_accepted,n_parsed,batch_bytes,batch_ns,ns_per_stmt" + "dialect,parser,n_eligible,k,n_correct,accuracy_pct,ns_per_stmt" ) .unwrap(); let parsers = BenchParser::all(); let start_all = Instant::now(); + println!("batch sampling: m={BATCH_M} statements, k={BATCH_K} batches per pair"); for &dialect in DIALECTS { let stmts = load_dialect(dialect); @@ -187,9 +194,12 @@ fn main() { if !parser.can_batch() || !parser.supports(dialect) { continue; } + // Skip parsers whose batch entry point does not report a true + // statement count (e.g. pg_query summary returns distinct types). + if !reports_statement_count(|s| safe_count(parser, s, dialect)) { + continue; + } let job_start = Instant::now(); - // Run on a large stack: deeply nested accepted statements can - // otherwise overflow the default stack and abort the process. 
let result = std::thread::scope(|scope| { std::thread::Builder::new() .stack_size(WORKER_STACK) @@ -206,33 +216,31 @@ fn main() { continue; }; + let acc = row + .accuracy_pct + .map_or_else(String::new, |a| format!("{a:.3}")); + let ns = row + .ns_per_stmt + .map_or_else(String::new, |n| format!("{n:.1}")); writeln!( summary, - "{},{},{},{},{},{:.1},{:.1}", - row.dialect, - row.parser, - row.n_accepted, - row.n_parsed, - row.batch_bytes, - row.batch_ns, - row.ns_per_stmt, + "{},{},{},{},{},{acc},{ns}", + row.dialect, row.parser, row.n_eligible, row.k, row.n_correct, ) .unwrap(); summary.flush().unwrap(); - let coverage = if row.n_accepted == 0 { - 0.0 - } else { - 100.0 * row.n_parsed as f64 / row.n_accepted as f64 - }; println!( - "{:<11} {:<24} n={:>6} seen={:>6} ({:>3.0}%) batch={:>9.0}ns/stmt ({:.1}s)", + "{:<11} {:<24} elig={:>6} ok={:>3}/{:<3} acc={:>6} batch={:>9}ns/stmt ({:.1}s)", row.dialect, row.parser, - row.n_accepted, - row.n_parsed, - coverage, - row.ns_per_stmt, + row.n_eligible, + row.n_correct, + row.k, + row.accuracy_pct + .map_or_else(|| "n/a".to_string(), |a| format!("{a:.1}%")), + row.ns_per_stmt + .map_or_else(|| "n/a".to_string(), |n| format!("{n:.0}")), job_start.elapsed().as_secs_f64(), ); } diff --git a/datasets.tar.zst b/datasets.tar.zst index 2f8ce01..1de3a6e 100644 Binary files a/datasets.tar.zst and b/datasets.tar.zst differ diff --git a/membench/src/main.rs b/membench/src/main.rs index ab62f6e..4ba900d 100644 --- a/membench/src/main.rs +++ b/membench/src/main.rs @@ -26,7 +26,7 @@ use std::fs; use std::io::Write as _; use std::path::Path; -use sql_ast_benchmark::batch::join_batch; +use sql_ast_benchmark::batch::{batch_eligible, evaluate_batches, reports_statement_count}; use sql_ast_benchmark::datasets::{ensure_corpus, Dialect}; use sql_ast_benchmark::stats::slug; use sql_ast_benchmark::BenchParser; @@ -172,16 +172,27 @@ fn run() { } } -/// Whole-script (batch) memory: one (peak, retained) pair per (parser, dialect), -/// normalized 
per statement, written to a single summary file. Only parsers with -/// a batch entry point whose memory is visible to the Rust allocator take part. +/// Parse one script to a statement count under panic protection, so a single +/// pathological input cannot abort the whole batch run. +fn safe_count(parser: BenchParser, sql: &str, dialect: Dialect) -> usize { + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + parser.parse_batch(sql, dialect).unwrap_or(0) + })) + .unwrap_or(0) +} + +/// Batch memory: per (parser, dialect) it samples the same random batches as the +/// time bench (deterministic seed), measures peak/retained over the batches that +/// reparse correctly, and records them normalized per statement to +/// `target/batch_mem_dist/summary.csv`. Only parsers whose memory is visible to +/// the Rust allocator take part (the libpg_query bindings report `None`). fn run_batch() { fs::create_dir_all(BATCH_OUT_DIR).expect("create batch_mem_dist dir"); let mut summary = fs::File::create(format!("{BATCH_OUT_DIR}/summary.csv")).expect("create summary.csv"); writeln!( summary, - "dialect,parser,n_accepted,n_parsed,peak_bytes,retained_bytes,peak_per_stmt,retained_per_stmt" + "dialect,parser,n_eligible,k,n_correct,accuracy_pct,peak_per_stmt,retained_per_stmt" ) .expect("write header"); @@ -194,46 +205,66 @@ fn run_batch() { if !parser.can_batch() || !parser.supports(dialect) { continue; } - let accepted: Vec<&str> = stmts - .iter() - .filter(|s| parser.accepts(s, dialect) == Some(true)) - .map(String::as_str) - .collect(); - if accepted.is_empty() { + // Skip parsers whose memory is invisible to the Rust allocator (the + // libpg_query bindings parse in C and report None). + if parser.measure_mem_batch("SELECT 1", dialect).is_none() { continue; } - let batch = join_batch(&accepted); - // Warm up: let one-time caches/lazy statics allocate first, so they - // raise the baseline rather than this measurement. 
Also skips - // parsers whose memory is invisible to the Rust allocator (None). - if parser.measure_mem_batch(&batch, dialect).is_none() { + // Skip parsers whose batch entry point does not report a true count. + if !reports_statement_count(|s| safe_count(parser, s, dialect)) { continue; } - let Some((peak, retained)) = parser.measure_mem_batch(&batch, dialect) else { - continue; + let eligible: Vec<&str> = stmts + .iter() + .filter(|s| { + parser.accepts(s, dialect) == Some(true) + && batch_eligible(s) + && safe_count(parser, s, dialect) == 1 + }) + .map(String::as_str) + .collect(); + let label = format!("{}/{}", dialect.dir_name(), parser.name()); + let eval = evaluate_batches(&eligible, &label, |s| safe_count(parser, s, dialect)); + + let (peak_per_stmt, retained_per_stmt) = if eval.n_correct == 0 { + (String::new(), String::new()) + } else { + let mut peak_sum = 0u128; + let mut ret_sum = 0u128; + for s in &eval.correct_scripts { + if let Some((peak, retained)) = parser.measure_mem_batch(s, dialect) { + peak_sum += peak as u128; + ret_sum += retained as u128; + } + } + let denom = (eval.n_correct * eval.effective_m) as f64; + ( + format!("{:.1}", peak_sum as f64 / denom), + format!("{:.1}", ret_sum as f64 / denom), + ) }; - // Statements the parser actually consumed from the script, so the - // export can drop a pair whose batch parse bailed out early. 
- let n_parsed = parser.parse_batch(&batch, dialect).unwrap_or(0); - let n = accepted.len() as f64; + let acc = eval + .accuracy_pct() + .map_or_else(String::new, |a| format!("{a:.3}")); writeln!( summary, - "{},{},{},{n_parsed},{peak},{retained},{:.1},{:.1}", + "{},{},{},{},{},{acc},{peak_per_stmt},{retained_per_stmt}", dialect.dir_name(), parser.name(), - accepted.len(), - peak as f64 / n, - retained as f64 / n, + eval.n_eligible, + eval.k, + eval.n_correct, ) .expect("write row"); summary.flush().expect("flush summary"); - let coverage = 100.0 * n_parsed as f64 / n; eprintln!( - "batch-mem {} {}: n={} seen={n_parsed} ({coverage:.0}%) peak={peak} retained={retained}", + "batch-mem {} {}: elig={} ok={}/{} peak/stmt={peak_per_stmt} ret/stmt={retained_per_stmt}", dialect.dir_name(), parser.name(), - accepted.len(), + eval.n_eligible, + eval.n_correct, + eval.k, ); } } diff --git a/oracle/labels/clickhouse.tsv.zst b/oracle/labels/clickhouse.tsv.zst index 730a01b..42b489f 100644 Binary files a/oracle/labels/clickhouse.tsv.zst and b/oracle/labels/clickhouse.tsv.zst differ diff --git a/oracle/labels/duckdb.tsv.zst b/oracle/labels/duckdb.tsv.zst index fb26d46..a5d9e80 100644 Binary files a/oracle/labels/duckdb.tsv.zst and b/oracle/labels/duckdb.tsv.zst differ diff --git a/oracle/labels/mysql.tsv.zst b/oracle/labels/mysql.tsv.zst index 4aac84e..74c149a 100644 Binary files a/oracle/labels/mysql.tsv.zst and b/oracle/labels/mysql.tsv.zst differ diff --git a/oracle/labels/postgresql.tsv.zst b/oracle/labels/postgresql.tsv.zst index 11e9d47..847ecc9 100644 Binary files a/oracle/labels/postgresql.tsv.zst and b/oracle/labels/postgresql.tsv.zst differ diff --git a/oracle/labels/sqlite.tsv.zst b/oracle/labels/sqlite.tsv.zst index ad69657..7c89cf0 100644 Binary files a/oracle/labels/sqlite.tsv.zst and b/oracle/labels/sqlite.tsv.zst differ diff --git a/oracle/labels/tsql.tsv.zst b/oracle/labels/tsql.tsv.zst index edd6fcb..1ff35a9 100644 Binary files a/oracle/labels/tsql.tsv.zst and 
b/oracle/labels/tsql.tsv.zst differ diff --git a/oracle/src/main.rs b/oracle/src/main.rs index 88305c6..ea2b3ea 100644 --- a/oracle/src/main.rs +++ b/oracle/src/main.rs @@ -117,32 +117,114 @@ async fn label_postgresql(stmts: &[String]) -> Result> { let port = node.get_host_port_ipv4(5432).await?; let conn_str = format!("host={host} port={port} user=postgres password=postgres dbname=postgres"); - let (client, connection) = tokio_postgres::connect(&conn_str, NoTls) - .await - .context("connect postgres")?; - tokio::spawn(async move { - let _ = connection.await; - }); let mut valid = Vec::with_capacity(stmts.len()); - for (i, s) in stmts.iter().enumerate() { - // Make sure no aborted transaction is left from a prior error. - let _ = client.batch_execute("ROLLBACK").await; - let _ = client.batch_execute("BEGIN").await; - let res = client.batch_execute(s).await; - let _ = client.batch_execute("ROLLBACK").await; - let v = match res { - Ok(()) => true, - Err(e) => e.code() != Some(&SqlState::SYNTAX_ERROR), - }; - valid.push(v); - if i % 2000 == 0 { - eprintln!(" postgresql {i}/{}", stmts.len()); + let mut reconnects = 0usize; + // A statement that terminates the backend twice at the same index is a + // confirmed "poison" (some pg_regress statements crash/kill the connection); + // it is marked invalid and skipped, mirroring the ClickHouse handling. + let mut death_idx: Option = None; + let mut death_count = 0usize; + + 'session: while valid.len() < stmts.len() { + let (client, connection) = tokio_postgres::connect(&conn_str, NoTls) + .await + .context("connect postgres")?; + tokio::spawn(async move { + let _ = connection.await; + }); + + let mut unreachable_streak = 0usize; + while valid.len() < stmts.len() { + let i = valid.len(); + // Make sure no aborted transaction is left from a prior error. 
+ let _ = client.batch_execute("ROLLBACK").await; + let _ = client.batch_execute("BEGIN").await; + let res = client.batch_execute(&stmts[i]).await; + let _ = client.batch_execute("ROLLBACK").await; + // A verdict only counts if the server actually answered: an error with + // a SQLSTATE is a real result (syntax -> invalid, anything else parsed + // -> valid). An error with no code is a transport/connection failure, + // which must never be recorded as "valid". + let verdict = match res { + Ok(()) => Some(true), + Err(e) => e.code().map(|code| code != &SqlState::SYNTAX_ERROR), + }; + match verdict { + Some(v) => { + unreachable_streak = 0; + valid.push(v); + if i.is_multiple_of(2000) { + eprintln!(" postgresql {i}/{}", stmts.len()); + } + } + None if is_copy_to_stdout(&stmts[i]) => { + // `COPY ... TO STDOUT` parses fine but then streams rows over + // the COPY sub-protocol, which `batch_execute` cannot consume, + // so it breaks the connection with no SQLSTATE. Reaching that + // stage proves it parsed (a syntax error would carry code + // 42601 and be a real verdict above), so it is valid. Record it + // and reconnect to replace the now-broken connection. + valid.push(true); + death_idx = None; + death_count = 0; + reconnects += 1; + anyhow::ensure!( + reconnects <= 50, + "postgres reconnected {reconnects} times (last at statement {i}); aborting without writing a label cache" + ); + continue 'session; + } + None if unreachable_streak + 1 < 6 => { + unreachable_streak += 1; + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + } + None => { + // Connection is gone: the backend died (often killed by the + // statement itself). Reconnect and resume; if the same index + // kills it twice, treat that statement as poison. 
+ if death_idx == Some(i) { + death_count += 1; + } else { + death_idx = Some(i); + death_count = 1; + } + if death_count >= 2 { + eprintln!( + " postgresql: statement {i} repeatedly kills the backend; marking invalid and skipping: {}", + stmts[i].chars().take(120).collect::() + ); + valid.push(false); + death_idx = None; + death_count = 0; + } else { + eprintln!( + " postgresql backend died at {i}/{}; reconnecting", + stmts.len() + ); + } + reconnects += 1; + anyhow::ensure!( + reconnects <= 50, + "postgres backend crashed {reconnects} times (last at statement {i}); aborting without writing a label cache" + ); + continue 'session; + } + } } } Ok(valid) } +/// Whether a statement is `COPY ... TO STDOUT`: valid SQL whose result is streamed +/// over the COPY sub-protocol, which the simple-query probe cannot consume (it +/// breaks the connection with no SQLSTATE). A syntactically invalid COPY instead +/// returns a real syntax error, so this only matches genuinely-valid ones. +fn is_copy_to_stdout(stmt: &str) -> bool { + let up = stmt.trim_start().to_ascii_uppercase(); + up.starts_with("COPY") && up.contains("TO STDOUT") +} + /// MySQL: real server in a container. We use `PREPARE`, MySQL's parse-only path: /// it parses (and name-resolves) without executing, so there are no side effects /// and nothing blocks. Invalid iff `PREPARE` fails with error 1064 @@ -164,23 +246,43 @@ async fn label_mysql(stmts: &[String]) -> Result> { let mut conn = pool.get_conn().await.context("connect mysql")?; let mut valid = Vec::with_capacity(stmts.len()); - for (i, s) in stmts.iter().enumerate() { - let stmt = s.trim().trim_end_matches(';'); + let mut unreachable_streak = 0usize; + let mut i = 0; + while i < stmts.len() { + let stmt = stmts[i].trim().trim_end_matches(';'); // Bind the statement text as a parameter (no injection), then PREPARE it. 
- let v = match conn.exec_drop("SET @q = ?", (stmt,)).await { + // Only a `Server` response is a real verdict (error 1064 = syntax -> + // invalid, any other server error parsed -> valid). A non-server error is + // a transport/connection failure and must never be recorded as "valid". + let verdict = match conn.exec_drop("SET @q = ?", (stmt,)).await { Ok(()) => match conn.query_drop("PREPARE _ck FROM @q").await { Ok(()) => { let _ = conn.query_drop("DEALLOCATE PREPARE _ck").await; - true + Some(true) } - Err(mysql_async::Error::Server(e)) => e.code != 1064, - Err(_) => true, + Err(mysql_async::Error::Server(e)) => Some(e.code != 1064), + Err(_) => None, }, - Err(_) => true, + Err(mysql_async::Error::Server(e)) => Some(e.code != 1064), + Err(_) => None, }; - valid.push(v); - if i % 2000 == 0 { - eprintln!(" mysql {i}/{}", stmts.len()); + match verdict { + Some(v) => { + unreachable_streak = 0; + valid.push(v); + i += 1; + if i.is_multiple_of(2000) { + eprintln!(" mysql {i}/{}", stmts.len()); + } + } + None => { + unreachable_streak += 1; + anyhow::ensure!( + unreachable_streak < 10, + "mysql became unreachable at statement {i}; aborting without writing a label cache" + ); + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } } } drop(conn); @@ -192,44 +294,153 @@ async fn label_mysql(stmts: &[String]) -> Result> { /// parses only (no execution, no tables needed). Invalid iff the exception code /// is 62 (SYNTAX_ERROR). Any other code (unknown table/identifier, not /// implemented) means it parsed, so it is valid. +/// +/// Hardened on two fronts: +/// +/// * Correctness: every response body is fully consumed before the connection is +/// reused (the undrained error body was what desynced later responses and +/// silently mislabeled statements valid), transport blips are retried, and a +/// result that cannot be classified is never assumed valid. 
+/// * Resilience: the pinned ClickHouse image segfaults nondeterministically under +/// the sustained full-corpus load. When the engine stops responding the +/// container is restarted and labeling resumes from the same statement (each +/// `EXPLAIN AST` is independent, so a fresh engine yields identical verdicts). +/// A restart cap stops an unrecoverable engine from looping forever. async fn label_clickhouse(stmts: &[String]) -> Result> { use testcontainers_modules::clickhouse::ClickHouse; use testcontainers_modules::testcontainers::runners::AsyncRunner; - let node = ClickHouse::default() - .start() - .await - .context("start clickhouse container")?; - let host = node.get_host().await?; - let port = node.get_host_port_ipv4(8123).await?; - let url = format!("http://{host}:{port}/"); - let client = reqwest::Client::new(); + let mut valid: Vec = Vec::with_capacity(stmts.len()); + let mut restarts = 0usize; + let mut poisoned: Vec = Vec::new(); + // Track repeated deaths at one index: a statement that crashes the engine twice + // in a row is a confirmed parser-crash ("poison") and is skipped as invalid. 
+ let mut death_idx: Option = None; + let mut death_count = 0usize; - let mut valid = Vec::with_capacity(stmts.len()); - for (i, s) in stmts.iter().enumerate() { - let query = format!("EXPLAIN AST {}", s.trim().trim_end_matches(';')); - let v = match client.post(&url).body(query).send().await { - Ok(resp) if resp.status().is_success() => true, + 'engine: while valid.len() < stmts.len() { + let node = ClickHouse::default() + .start() + .await + .context("start clickhouse container")?; + let host = node.get_host().await?; + let port = node.get_host_port_ipv4(8123).await?; + let url = format!("http://{host}:{port}/"); + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .context("build clickhouse http client")?; + + let mut consecutive_unreachable = 0usize; + while valid.len() < stmts.len() { + let i = valid.len(); + let query = format!("EXPLAIN AST {}", stmts[i].trim().trim_end_matches(';')); + match clickhouse_classify(&client, &url, &query).await { + Some(v) => { + consecutive_unreachable = 0; + valid.push(v); + if i.is_multiple_of(5000) { + eprintln!(" clickhouse {i}/{}", stmts.len()); + } + } + None if consecutive_unreachable + 1 < 6 => { + // A transient blip: wait and retry the SAME statement (do not + // advance, do not guess a verdict). + consecutive_unreachable += 1; + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } + None => { + // Engine unreachable: it has crashed. Was the crash provoked by + // this exact statement (it died here last restart too)? 
+ if death_idx == Some(i) { + death_count += 1; + } else { + death_idx = Some(i); + death_count = 1; + } + drop(node); + if death_count >= 2 { + eprintln!( + " clickhouse: statement {i} repeatedly crashes the engine; marking invalid and skipping: {}", + stmts[i].chars().take(120).collect::() + ); + valid.push(false); + poisoned.push(i); + death_idx = None; + death_count = 0; + } else { + eprintln!( + " clickhouse unreachable at {i}/{}; restarting engine", + stmts.len() + ); + } + restarts += 1; + anyhow::ensure!( + restarts <= 50, + "ClickHouse crashed {restarts} times (last at statement {i}); aborting without writing a label cache" + ); + continue 'engine; + } + } + } + } + if !poisoned.is_empty() { + eprintln!( + " clickhouse: {} statement(s) crashed the engine and were marked invalid (indices: {:?})", + poisoned.len(), + poisoned + ); + } + Ok(valid) +} + +/// Classify one ClickHouse `EXPLAIN AST` request, retrying transient transport +/// failures. The response body is always fully read before returning, so a +/// connection is never left mid-stream (the bug that desynced reused connections). +/// `Some(true)` if the request succeeded (2xx) or failed with a non-syntax +/// exception code (the statement parsed, the engine just could not resolve or +/// execute it); `Some(false)` for exception code 62 (`SYNTAX_ERROR`) or an +/// unclassifiable response; `None` if the engine was unreachable after retries. 
+async fn clickhouse_classify(client: &reqwest::Client, url: &str, query: &str) -> Option { + for attempt in 0..3 { + match client.post(url).body(query.to_string()).send().await { Ok(resp) => { - let code = resp + let success = resp.status().is_success(); + let header_code = resp .headers() .get("x-clickhouse-exception-code") .and_then(|h| h.to_str().ok()) .and_then(|s| s.parse::().ok()); - match code { + let body = resp.text().await.unwrap_or_default(); + if success { + return Some(true); + } + return Some(match header_code.or_else(|| parse_clickhouse_code(&body)) { Some(62) => false, Some(_) => true, - None => !resp.text().await.unwrap_or_default().contains("Code: 62."), - } + None => false, + }); } - Err(_) => true, - }; - valid.push(v); - if i % 5000 == 0 { - eprintln!(" clickhouse {i}/{}", stmts.len()); + Err(_) if attempt < 2 => { + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + } + Err(_) => return None, } } - Ok(valid) + None +} + +/// Parse the leading exception code from a ClickHouse error body, e.g. +/// `"Code: 62. DB::Exception: ..."` -> `Some(62)`. +fn parse_clickhouse_code(body: &str) -> Option { + let digits: String = body + .trim_start() + .strip_prefix("Code:")? + .trim_start() + .chars() + .take_while(char::is_ascii_digit) + .collect(); + digits.parse().ok() } /// SQL Server (T-SQL): real server in a container. `SET PARSEONLY ON` parses @@ -273,14 +484,38 @@ async fn label_tsql(stmts: &[String]) -> Result> { .await?; let mut valid = Vec::with_capacity(stmts.len()); - for (i, s) in stmts.iter().enumerate() { - let v = match client.simple_query(s.as_str()).await { - Ok(stream) => stream.into_results().await.is_ok(), - Err(_) => false, + let mut unreachable_streak = 0usize; + let mut i = 0; + while i < stmts.len() { + // Under PARSEONLY the only `Server` error is a syntax error (invalid). A + // non-server error is a transport/connection failure: never record it as a + // verdict, retry, and abort if the engine stays unreachable. 
+ let verdict = match client.simple_query(stmts[i].as_str()).await { + Ok(stream) => match stream.into_results().await { + Ok(_) => Some(true), + Err(tiberius::error::Error::Server(_)) => Some(false), + Err(_) => None, + }, + Err(tiberius::error::Error::Server(_)) => Some(false), + Err(_) => None, }; - valid.push(v); - if i % 2000 == 0 { - eprintln!(" tsql {i}/{}", stmts.len()); + match verdict { + Some(v) => { + unreachable_streak = 0; + valid.push(v); + i += 1; + if i.is_multiple_of(2000) { + eprintln!(" tsql {i}/{}", stmts.len()); + } + } + None => { + unreachable_streak += 1; + anyhow::ensure!( + unreachable_streak < 10, + "sql server became unreachable at statement {i}; aborting without writing a label cache" + ); + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } } } Ok(valid) @@ -346,6 +581,18 @@ fn label_sqlite(stmts: &[String]) -> Result> { }); let out = child.wait_with_output().context("sqlite3 wait")?; let _ = writer.join(); + + // sqlite3 normally exits 0 (clean) or 1 (some statement errored, `.bail off` + // keeps going). A crash or container failure surfaces as the container exit + // code >= 128 (128 + signal, e.g. 139 = SIGSEGV) or a docker error (125-127), + // or no code at all. In those cases the script stopped early, so the unscanned + // tail would silently default to "valid" -- abort instead of writing garbage. + match out.status.code() { + Some(0 | 1) => {} + other => anyhow::bail!( + "sqlite3 ended abnormally (exit {other:?}); a statement likely crashed the CLI. 
Aborting without writing a label cache" + ), + } let stderr = String::from_utf8_lossy(&out.stderr); let mut valid = vec![true; stmts.len()]; @@ -391,7 +638,35 @@ fn is_sqlite_invalid(msg: &str) -> bool { #[cfg(test)] mod tests { - use super::{is_sqlite_invalid, parse_sqlite_err}; + use super::{is_copy_to_stdout, is_sqlite_invalid, parse_clickhouse_code, parse_sqlite_err}; + + #[test] + fn copy_to_stdout_is_recognized() { + assert!(is_copy_to_stdout("COPY (SELECT 1) TO STDOUT")); + assert!(is_copy_to_stdout("copy (select 1) to stdout")); + assert!(is_copy_to_stdout("COPY (SELECT 1) TO STDOUT WITH CSV")); + assert!(is_copy_to_stdout(" COPY t TO STDOUT")); + // Not COPY-to-stdout: a real syntax verdict handles these, or they differ. + assert!(!is_copy_to_stdout("SELECT 'COPY x TO STDOUT'")); + assert!(!is_copy_to_stdout("COPY t FROM STDIN")); + assert!(!is_copy_to_stdout("SELECT 1")); + } + + #[test] + fn parse_clickhouse_code_reads_leading_code() { + assert_eq!( + parse_clickhouse_code("Code: 62. DB::Exception: Syntax error: ..."), + Some(62) + ); + assert_eq!( + parse_clickhouse_code("Code: 47. DB::Exception: Unknown identifier"), + Some(47) + ); + assert_eq!(parse_clickhouse_code(" Code: 999. foo"), Some(999)); + // No parseable code -> None, which the caller treats as invalid. + assert_eq!(parse_clickhouse_code("totally unexpected body"), None); + assert_eq!(parse_clickhouse_code(""), None); + } #[test] fn missing_object_errors_are_valid() { diff --git a/sqlparser-create-user-terminator-bug.md b/sqlparser-create-user-terminator-bug.md new file mode 100644 index 0000000..a53da3f --- /dev/null +++ b/sqlparser-create-user-terminator-bug.md @@ -0,0 +1,144 @@ +# `CREATE USER` and `ALTER USER ... SET` consume the statement terminator, breaking any following statement + +## Summary + +In sqlparser, `CREATE USER ` (and `ALTER USER SET ...`) parse correctly on their own, but when one is followed by another statement in the same script the parse fails. 
The shared helper `parse_key_value_options`, used to read the trailing option list, consumes the `;` terminator. The top-level statement loop then no longer sees a separator before the next statement and returns `Expected: end of statement, found: <token>`, pointing at the first token after the semicolon. + +The defect is in `parse_key_value_options` itself, so it affects every statement that ends by calling it in unparenthesized mode. In 0.62.0 there are three such call sites: `CREATE USER`, `ALTER USER ... SET <options>`, and `ALTER USER ... SET TAG ...`. The bug is dialect independent (`GenericDialect`, `MySqlDialect`, `PostgreSqlDialect`, and `SnowflakeDialect` all behave identically). Statements that do not reach the helper (for example `CREATE ROLE`, `DROP USER`, and `ALTER USER ... RENAME`) are unaffected. + +## Affected versions + +Reproduced on `sqlparser` 0.62.0 (crates.io) and on current `main`. + +## Reproduction + +`Cargo.toml`: + +```toml +[dependencies] +sqlparser = "0.62.0" +``` + +`src/main.rs`: + +```rust +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::Parser; + +fn check(sql: &str) { + match Parser::parse_sql(&GenericDialect {}, sql) { + Ok(v) => println!("{sql:<46} -> Ok({} statements)", v.len()), + Err(e) => println!("{sql:<46} -> {e}"), + } +} + +fn main() { + // Affected: each ends in an unparenthesized key-value option list. + check("CREATE USER user1; SELECT 1"); + check("ALTER USER user1 SET x = 'y'; SELECT 1"); + check("ALTER USER user1 SET TAG t = 'v'; SELECT 1"); + + // Fine on their own (the terminator is followed by EOF). + check("CREATE USER user1"); + check("ALTER USER user1 SET x = 'y'"); + + // Unaffected: never reach parse_key_value_options. 
+ check("SELECT 1; CREATE USER user1"); + check("CREATE ROLE role1; SELECT 1"); + check("DROP USER user1; SELECT 1"); + check("ALTER USER user1 RENAME TO user2; SELECT 1"); + check("SELECT 1; SELECT 2"); +} +``` + +## Observed behavior + +```text +CREATE USER user1; SELECT 1 -> Expected: end of statement, found: SELECT at Line: 1, Column: 20 +ALTER USER user1 SET x = 'y'; SELECT 1 -> Expected: end of statement, found: SELECT at Line: 1, Column: 31 +ALTER USER user1 SET TAG t = 'v'; SELECT 1 -> Expected: end of statement, found: SELECT at Line: 1, Column: 35 +CREATE USER user1 -> Ok(1 statements) +ALTER USER user1 SET x = 'y' -> Ok(1 statements) +SELECT 1; CREATE USER user1 -> Ok(2 statements) +CREATE ROLE role1; SELECT 1 -> Ok(2 statements) +DROP USER user1; SELECT 1 -> Ok(2 statements) +ALTER USER user1 RENAME TO user2; SELECT 1 -> Ok(2 statements) +SELECT 1; SELECT 2 -> Ok(2 statements) +``` + +The first three inputs fail. Each affected statement parses alone, and the following statement parses alone, yet the two together fail. The affected statement even works when it is the last statement (`SELECT 1; CREATE USER user1` is `Ok(2)`), because then the terminator is followed by EOF and nothing is left to mis-parse. The reported column is always the position of the token immediately after the `;`, which shows the terminator has already been consumed by the time the error is raised. + +## Expected behavior + +`CREATE USER user1; SELECT 1` (and the two `ALTER USER ... SET` forms) should parse as two statements, the same way `CREATE ROLE role1; SELECT 1` and `SELECT 1; SELECT 2` do. + +## Root cause + +`parse_key_value_options` (src/parser/mod.rs, around line 20449) drives its loop with `self.next_token()`, which advances past the token it returns. Its terminator arm (around line 20468) breaks on a semicolon that has already been consumed: + +```rust +loop { + match self.next_token().token { + // ... 
+ Token::EOF | Token::SemiColon => break, // the ';' is consumed, then we break + // ... + } +} +``` + +So when the option list is unparenthesized and ends at a `;`, the `;` is eaten and discarded. Control returns to the top-level statement loop, which expects a `;` separator (or EOF) before the next statement. Because the separator is gone, it sees the next statement's first token directly and fails with `Expected: end of statement, found: <token>`. + +The three unparenthesized call sites in 0.62.0 are: + +- `parse_create_user`, src/parser/mod.rs around line 5224: `self.parse_key_value_options(false, &[Keyword::WITH, Keyword::TAG])`. +- `parse_alter_user`, src/parser/alter.rs around line 262 (the `SET TAG` branch): `self.parse_key_value_options(false, &[])`. +- `parse_alter_user`, src/parser/alter.rs around line 280 (the `SET <options>` branch): `self.parse_key_value_options(false, &[])`. + +This explains every case above: + +- `CREATE USER user1` alone: the loop reads `EOF` and breaks. Fine. +- `CREATE USER user1; SELECT 1`: the loop consumes `;` and breaks, then the top level sees `SELECT` with no preceding separator. Error. +- `SELECT 1; CREATE USER user1`: `SELECT` is parsed normally and does not eat its trailing `;`, then `CREATE USER` is parsed last and ends at `EOF`. Fine. +- `CREATE ROLE`, `DROP USER`, `ALTER USER ... RENAME`: they do not route through `parse_key_value_options`, so they terminate correctly. + +The parenthesized callers (`parse_key_value_options(true, ...)`, used by the Snowflake `FILE_FORMAT`, `COPY`, and similar option lists) are unaffected, because they end on `)` rather than on the statement terminator. + +## Suggested fix + +Do not consume the terminator. Put the semicolon back before breaking so the caller and the top-level statement loop can see it, for example: + +```rust +loop { + match self.next_token().token { + // ... + Token::EOF => break, + Token::SemiColon => { + self.prev_token(); + break; + } + // ... 
+ } +} +``` + +(The `EOF` case needs no `prev_token`.) Peeking instead of consuming would work as well. This mirrors how the `end_words` arm already calls `self.prev_token()` before breaking. Fixing the single helper repairs all three statements at once. + +## Impact + +Any multi-statement script that contains `CREATE USER` or `ALTER USER ... SET` in a non-final position fails to parse in full. This was found while benchmarking sqlparser on whole-script (multi-statement) parsing of real-world SQL corpora, where a single such statement voids the entire script because `parse_sql` is all-or-nothing. + +## Suggested regression test + +```rust +#[test] +fn key_value_option_statements_do_not_swallow_following_statement() { + for sql in [ + "CREATE USER user1; SELECT 1", + "ALTER USER user1 SET x = 'y'; SELECT 1", + "ALTER USER user1 SET TAG t = 'v'; SELECT 1", + ] { + let stmts = Parser::parse_sql(&GenericDialect {}, sql).unwrap(); + assert_eq!(stmts.len(), 2, "{sql}"); + } +} +``` diff --git a/src/batch.rs b/src/batch.rs index 4a169ed..a784662 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -1,51 +1,287 @@ -//! Shared construction of a multi-statement script for the batch benchmarks. +//! Shared construction and sampling for the batch benchmarks. //! -//! Both the batch time bench (`benches/batch_parsing.rs`) and the batch memory -//! bench (`membench -- batch`) must feed parsers byte-identical input, so the -//! join lives here in one place rather than in each binary. +//! The batch axis measures how a parser handles a multi-statement script. Rather +//! than concatenating a parser's whole accepted set (where one statement that +//! mishandles the terminator makes the all-or-nothing `parse_sql` return zero and +//! voids the entire measurement), we draw `BATCH_K` random batches of `BATCH_M` +//! statements from the set the parser can individually digest, parse each as one +//! script, and report the share that reparse to the exact expected count plus the +//! 
time and memory over the batches that did. The time bench +//! (`benches/batch_parsing.rs`), the memory bench (`membench -- batch`), and the +//! time machine (`timemachine`) all use the helpers here so they sample and join +//! identically. -/// Join accepted statements into a single multi-statement script. +/// Statements per sampled batch. +pub const BATCH_M: usize = 128; + +/// Number of sampled batches per (parser, dialect). +pub const BATCH_K: usize = 200; + +/// Base seed for the deterministic sampler. Mixed per (parser, dialect) so each +/// pair samples reproducibly but distinctly. +pub const BATCH_SEED: u64 = 0x5108_5A17_B47C_0DE5; + +/// A three-distinct-statement probe to check a parser reports a true count. +/// +/// A parser whose batch entry point returns something other than 3 here (for +/// example `pg_query` summary mode, which returns the number of distinct +/// statement types) cannot be scored on batch accuracy and is left out of the +/// batch axis. +pub const COUNT_PROBE: &str = "SELECT 1\n;\nSELECT 2\n;\nSELECT 3"; + +/// Whether `count` (a parser's whole-script statement count) reports a true +/// statement count, checked against [`COUNT_PROBE`]. +pub fn reports_statement_count(mut count: impl FnMut(&str) -> usize) -> bool { + count(COUNT_PROBE) == 3 +} + +/// Join statements into a single multi-statement script. /// -/// Each corpus statement is one line, so a `;`-and-newline separator yields an -/// unambiguous script. A trailing `;` on a statement is stripped first to avoid -/// an empty statement between terminators. The last statement gets no terminator -/// (none is required at end of input). +/// The separator is a newline, then the `;` terminator, then a newline. 
The +/// leading newline is essential: a corpus statement is a single line and may end +/// in a `--` (or `#`) line comment, which runs to end of line, so a terminator +/// placed on the same line would be swallowed by that comment and silently merge +/// two statements into one. Putting the terminator on its own line closes any +/// trailing line comment first. A trailing `;` is stripped to avoid an empty +/// statement between terminators, and the last statement gets no terminator. #[must_use] -pub fn join_batch(accepted: &[&str]) -> String { - let mut out = String::with_capacity(accepted.iter().map(|s| s.len() + 2).sum()); - for (i, s) in accepted.iter().enumerate() { +pub fn join_batch(stmts: &[&str]) -> String { + let mut out = String::with_capacity(stmts.iter().map(|s| s.len() + 3).sum()); + for (i, s) in stmts.iter().enumerate() { if i > 0 { - out.push_str(";\n"); + out.push_str("\n;\n"); } out.push_str(s.trim().trim_end_matches(';').trim_end()); } out } +/// Whether a statement is safe to place in a concatenated batch script. +/// +/// `COPY ... FROM STDIN` reads the lines that follow it as inline data until a +/// `\.` terminator, so in a single script it swallows every statement after it. +/// It parses fine on its own, so it stays in the per-statement benchmarks, but it +/// must be excluded from the batch. Statements are single-line, so a token scan +/// is enough. +#[must_use] +pub fn batch_eligible(stmt: &str) -> bool { + let toks: Vec = stmt + .split_whitespace() + .map(str::to_ascii_lowercase) + .collect(); + let is_copy_from_stdin = toks.iter().any(|t| t == "copy") + && toks.windows(2).any(|w| w[0] == "from" && w[1] == "stdin"); + !is_copy_from_stdin +} + +/// Deterministic `SplitMix64`. Used to sample batches reproducibly without +/// pulling in an RNG dependency (the rest of the benchmark is deterministic). 
+struct SplitMix64(u64); + +impl SplitMix64 { + const fn new(seed: u64) -> Self { + Self(seed) + } + + const fn next_u64(&mut self) -> u64 { + self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = self.0; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) + } + + /// Uniform-ish index in `0..n` (n > 0). Modulo bias is negligible here. + const fn below(&mut self, n: usize) -> usize { + (self.next_u64() % n as u64) as usize + } +} + +/// A reproducible per-pair seed derived from the base seed and a label. +#[must_use] +pub fn seed_for(label: &str) -> u64 { + let mut h = BATCH_SEED; + for &b in label.as_bytes() { + h ^= u64::from(b); + h = h.wrapping_mul(0x0000_0100_0000_01b3); + } + h +} + +/// Sample `k` batches of distinct indices from `0..n`. +/// +/// Each batch holds `min(m, n)` distinct indices (partial Fisher-Yates), batches +/// may overlap, and the result is deterministic for a given `seed`. Returns an +/// empty vec when `n == 0`. +#[must_use] +pub fn sample_batches(n: usize, m: usize, k: usize, seed: u64) -> Vec> { + if n == 0 { + return Vec::new(); + } + let take = m.min(n); + let mut rng = SplitMix64::new(seed); + let mut pool: Vec = (0..n).collect(); + let mut out = Vec::with_capacity(k); + for _ in 0..k { + // Partial Fisher-Yates: swap `take` random picks to the front, then read. + for i in 0..take { + let j = i + rng.below(n - i); + pool.swap(i, j); + } + out.push(pool[..take].to_vec()); + } + out +} + +/// Result of measuring a parser on `k` sampled batches. +pub struct BatchEval { + /// Statements eligible for batching (accepted, single, not input-consuming). + pub n_eligible: usize, + /// Distinct statements per batch actually used (`min(BATCH_M, n_eligible)`). + pub effective_m: usize, + /// Number of batches attempted. + pub k: usize, + /// Batches that reparsed to exactly `effective_m` statements. 
+ pub n_correct: usize, + /// The joined scripts of the correct batches, for timing or memory probing. + pub correct_scripts: Vec, +} + +impl BatchEval { + /// Accuracy as a percentage, or `None` when nothing was eligible. + #[must_use] + pub fn accuracy_pct(&self) -> Option { + (self.k > 0).then(|| 100.0 * self.n_correct as f64 / self.k as f64) + } +} + +/// Sample batches from `eligible` and find those that reparse to the full count. +/// +/// Draws `BATCH_K` batches of `BATCH_M` (seeded reproducibly by `label`), joins +/// each, and uses `count` (the parser's whole-script statement count) to keep the +/// batches that reparse to exactly `effective_m`. `eligible` must already be +/// filtered to statements the parser accepts, that parse to exactly one statement +/// alone, and that satisfy [`batch_eligible`]. +pub fn evaluate_batches( + eligible: &[&str], + label: &str, + mut count: impl FnMut(&str) -> usize, +) -> BatchEval { + let n = eligible.len(); + let batches = sample_batches(n, BATCH_M, BATCH_K, seed_for(label)); + let effective_m = BATCH_M.min(n); + let mut correct_scripts = Vec::new(); + for idxs in &batches { + let stmts: Vec<&str> = idxs.iter().map(|&i| eligible[i]).collect(); + let script = join_batch(&stmts); + if count(&script) == effective_m { + correct_scripts.push(script); + } + } + BatchEval { + n_eligible: n, + effective_m, + k: batches.len(), + n_correct: correct_scripts.len(), + correct_scripts, + } +} + #[cfg(test)] mod tests { - use super::join_batch; + use super::*; #[test] - fn joins_with_terminators_and_strips_trailing_semicolons() { + fn joins_with_terminator_on_its_own_line() { assert_eq!( join_batch(&["SELECT 1;", "SELECT 2"]), - "SELECT 1;\nSELECT 2" + "SELECT 1\n;\nSELECT 2" ); - // Already-terminated and whitespace-padded statements normalize cleanly. 
assert_eq!( join_batch(&[" SELECT 1 ; ", "SELECT 2 ;"]), - "SELECT 1;\nSELECT 2" + "SELECT 1\n;\nSELECT 2" ); } + #[test] + fn terminator_survives_a_trailing_line_comment() { + let joined = join_batch(&["SELECT 1 -- note", "SELECT 2"]); + assert_eq!(joined, "SELECT 1 -- note\n;\nSELECT 2"); + assert!(joined.contains("\n;\n")); + } + #[test] fn single_statement_has_no_terminator() { assert_eq!(join_batch(&["SELECT 1"]), "SELECT 1"); + assert_eq!(join_batch(&[]), ""); } #[test] - fn empty_input_is_empty() { - assert_eq!(join_batch(&[]), ""); + fn copy_from_stdin_is_excluded() { + assert!(!batch_eligible("COPY t FROM STDIN")); + assert!(!batch_eligible("copy t from stdin null 'x'")); + assert!(batch_eligible("SELECT 1")); + assert!(batch_eligible("INSERT INTO t SELECT * FROM other")); + } + + #[test] + fn sampler_is_deterministic_distinct_and_sized() { + let a = sample_batches(1000, 128, 200, 42); + let b = sample_batches(1000, 128, 200, 42); + assert_eq!(a, b, "same seed gives same batches"); + assert_ne!( + a, + sample_batches(1000, 128, 200, 43), + "seed changes batches" + ); + assert_eq!(a.len(), 200); + for batch in &a { + assert_eq!(batch.len(), 128); + let mut sorted = batch.clone(); + sorted.sort_unstable(); + sorted.dedup(); + assert_eq!(sorted.len(), 128, "indices within a batch are distinct"); + assert!(batch.iter().all(|&i| i < 1000)); + } + } + + #[test] + fn sampler_handles_small_and_empty_pools() { + assert!(sample_batches(0, 128, 200, 1).is_empty()); + let small = sample_batches(10, 128, 5, 1); + assert_eq!(small.len(), 5); + for batch in &small { + assert_eq!(batch.len(), 10, "effective_m caps at the pool size"); + } + } + + #[test] + fn accuracy_drops_when_a_swallower_is_present() { + // A toy "parser": counts ';'-separated parts, but a statement that begins + // with SWALLOW eats the rest of the script (returns 1). Mirrors how a real + // terminator bug collapses the count. 
+ let count = |script: &str| { + if script.contains("SWALLOW") { + 1 + } else { + script.split("\n;\n").count() + } + }; + let mut clean: Vec<&str> = Vec::new(); + let owned: Vec = (0..500).map(|i| format!("SELECT {i}")).collect(); + for s in &owned { + clean.push(s); + } + let ok = evaluate_batches(&clean, "clean", count); + assert_eq!(ok.accuracy_pct(), Some(100.0)); + + let mut withbug = clean.clone(); + withbug.push("SWALLOW"); + let bug = evaluate_batches(&withbug, "bug", count); + let acc = bug.accuracy_pct().unwrap(); + assert!( + acc > 0.0 && acc < 100.0, + "accuracy {acc} should be between 0 and 100" + ); } } diff --git a/src/bin/build_proc_suites.rs b/src/bin/build_proc_suites.rs new file mode 100644 index 0000000..2708b57 --- /dev/null +++ b/src/bin/build_proc_suites.rs @@ -0,0 +1,549 @@ +//! Rebuild the Spark SQL and Oracle corpus files from their original sources, +//! keeping compound statements (`BEGIN ... END`, PL/SQL blocks) intact. The +//! original extractor split on every `;`, shredding Spark SQL scripting blocks +//! and Oracle PL/SQL blocks into invalid fragments (issue #22, provenance side). +//! +//! Spark source: apache/spark `sql/core/src/test/resources/sql-tests/inputs`. +//! Spark's own harness wraps any statement that contains inner `;` (the scripting +//! `BEGIN ... END` blocks) in `--QUERY-DELIMITER-START` / `--QUERY-DELIMITER-END` +//! markers, so we honor those: text between a marker pair is one statement, +//! everything else splits on `;`. +//! +//! Oracle source: oracle-samples/db-sample-schemas. These are SQL*Plus scripts: +//! a PL/SQL block (`DECLARE`/`BEGIN`/`CREATE ... PROCEDURE|FUNCTION|PACKAGE| +//! TRIGGER|TYPE`) runs until a line containing only `/`; every other statement +//! ends at `;`. +//! +//! cargo run --release --bin build_proc_suites -- +//! +//! Then repack `datasets.tar.zst` (Spark and Oracle are provenance, no oracle). 
+ +#![allow( + clippy::doc_markdown, + clippy::too_many_lines, + clippy::items_after_statements +)] + +use std::collections::HashSet; +use std::fs; +use std::path::{Path, PathBuf}; + +/// Collapse a raw statement to one trimmed line (drops comments already removed). +fn normalize(s: &str) -> String { + s.split_whitespace().collect::>().join(" ") +} + +/// Copy a quoted literal verbatim from `chars[i..]` into `buf`, returning the +/// index just past the closing quote. Handles `'`, `"`, backtick (doubling +/// escape) and `[` (closed by `]`, no escape). +fn copy_quote(chars: &[char], mut i: usize, buf: &mut String) -> usize { + let open = chars[i]; + let close = if open == '[' { ']' } else { open }; + buf.push(open); + i += 1; + while i < chars.len() { + let d = chars[i]; + if d == close { + if close != ']' && chars.get(i + 1) == Some(&close) { + buf.push(d); + buf.push(d); + i += 2; + continue; + } + buf.push(d); + return i + 1; + } + buf.push(d); + i += 1; + } + i +} + +/// Split Spark golden-test SQL into statements, honoring `--QUERY-DELIMITER` +/// regions (one statement each) and otherwise splitting on top-level `;`. Lines +/// that are pure directive comments (`--CONFIG`, `--SET`, `--IMPORT`, ...) are +/// dropped; trailing `--` and `/* */` comments are stripped. +fn split_spark(input: &str) -> Vec { + let mut out = Vec::new(); + let mut buf = String::new(); + let mut region = false; + + for raw_line in input.lines() { + let trimmed = raw_line.trim_start(); + if trimmed.starts_with("--QUERY-DELIMITER-START") { + region = true; + continue; + } + if trimmed.starts_with("--QUERY-DELIMITER-END") { + let s = normalize(&buf); + if !s.is_empty() { + out.push(s); + } + buf.clear(); + region = false; + continue; + } + if region { + // Whole region is one statement; keep code, drop full-line comments. 
+ if !trimmed.starts_with("--") { + strip_line_into(raw_line, &mut buf, &mut Vec::new(), true); + buf.push(' '); + } + continue; + } + if trimmed.starts_with("--") { + continue; // directive / comment line + } + // Normal line: split on `;`, stripping inline comments and quotes. + strip_line_into(raw_line, &mut buf, &mut out, false); + buf.push(' '); + } + let s = normalize(&buf); + if !s.is_empty() { + out.push(s); + } + out +} + +/// Append `line` to `buf`, stripping comments and copying quotes verbatim. When +/// `region_only` is false, a top-level `;` flushes `buf` (normalized) into `out`. +fn strip_line_into(line: &str, buf: &mut String, out: &mut Vec, region_only: bool) { + let chars: Vec = line.chars().collect(); + let mut i = 0; + while i < chars.len() { + let c = chars[i]; + if c == '-' && chars.get(i + 1) == Some(&'-') { + break; // rest of line is a comment + } + if c == '/' && chars.get(i + 1) == Some(&'*') { + i += 2; + while i < chars.len() && !(chars[i] == '*' && chars.get(i + 1) == Some(&'/')) { + i += 1; + } + i += 2; + continue; + } + if matches!(c, '\'' | '"' | '`' | '[') { + i = copy_quote(&chars, i, buf); + continue; + } + if c == ';' && !region_only { + let s = normalize(buf); + if !s.is_empty() { + out.push(s); + } + buf.clear(); + i += 1; + continue; + } + buf.push(c); + i += 1; + } +} + +/// Split a comment-free string on top-level `;`, respecting quoted literals. +fn split_semicolons(s: &str) -> Vec { + let chars: Vec = s.chars().collect(); + let mut out = Vec::new(); + let mut buf = String::new(); + let mut i = 0; + while i < chars.len() { + let c = chars[i]; + if matches!(c, '\'' | '"' | '`' | '[') { + i = copy_quote(&chars, i, &mut buf); + continue; + } + if c == ';' { + out.push(std::mem::take(&mut buf)); + i += 1; + continue; + } + buf.push(c); + i += 1; + } + out.push(buf); + out +} + +/// Harvest the standalone DML statements from inside a PL/SQL block, so the bulk +/// `INSERT`/`UPDATE`/... 
that the block wraps remain individual corpus entries. A +/// leading `BEGIN` glued to the first inner statement is stripped. Non-DML pieces +/// (declarations, control flow, BEGIN/END) are dropped. +fn harvest_dml(block: &str) -> Vec { + let mut out = Vec::new(); + for piece in split_semicolons(block) { + let mut p = normalize(&piece); + if let Some(rest) = p + .strip_prefix("BEGIN ") + .or_else(|| p.strip_prefix("begin ")) + { + p = rest.trim().to_string(); + } + let first = p + .split_whitespace() + .next() + .unwrap_or("") + .to_ascii_uppercase(); + if matches!( + first.as_str(), + "INSERT" | "UPDATE" | "DELETE" | "SELECT" | "MERGE" | "WITH" + ) { + out.push(p); + } + } + out +} + +/// Split Oracle SQL*Plus script text into `(normal, special)`: normal per-statement +/// corpus entries, and special whole PL/SQL anonymous blocks (kept once, isolated +/// from the per-statement metrics). A `/` line ends a block; `;` ends other +/// statements. Anonymous `DECLARE`/`BEGIN` blocks go to `special`, and their inner +/// DML is also harvested into `normal`; `CREATE ... PROCEDURE/...` blocks are kept +/// whole in `normal` (real DDL statements). +fn split_oracle(input: &str) -> (Vec, Vec) { + let mut normal = Vec::new(); + let mut special = Vec::new(); + let mut buf = String::new(); + let mut in_block = false; + let mut anon = false; + let mut started = false; + + for raw_line in input.lines() { + let trimmed = raw_line.trim(); + // SQL*Plus block terminator: end the current PL/SQL block. + if trimmed == "/" { + let s = normalize(&buf); + if !s.is_empty() { + if anon { + special.push(s); + normal.extend(harvest_dml(&buf)); + } else { + normal.push(s); + } + } + buf.clear(); + in_block = false; + anon = false; + started = false; + continue; + } + // Skip pure comment lines and SQL*Plus client directives (REM, PROMPT, + // SET, ACCEPT, etc.) when no statement is in progress. 
These are not SQL + // and, left in the buffer, would also set `started` and mask a following + // `BEGIN` block opener (the ACCEPT ... HIDE / BEGIN IF ... pattern). + if buf.trim().is_empty() { + let up = trimmed.to_ascii_uppercase(); + // Leading SQL*Plus command words (skip the whole line when one starts it). + const DIRECTIVES: &[&str] = &[ + "PROMPT", + "SET ", + "DEFINE", + "UNDEFINE", + "SPOOL", + "WHENEVER", + "CONNECT", + "ALTER SESSION", + "COLUMN ", + "ACCEPT ", + "PAUSE", + "EXEC ", + "EXECUTE ", + "VARIABLE ", + "VAR ", + "PRINT ", + "SHOW ", + "BREAK", + "COMPUTE ", + "TTITLE", + "BTITLE", + "STORE ", + "SAVE ", + "HOST", + "CLEAR ", + "TIMING", + "START ", + "ACCEPT", + ]; + if trimmed.is_empty() + || trimmed.starts_with("--") + || trimmed.starts_with('@') + || up.starts_with("REM ") + || up == "REM" + || DIRECTIVES.iter().any(|d| up.starts_with(d)) + { + continue; + } + } + + let chars: Vec = raw_line.chars().collect(); + let mut i = 0; + while i < chars.len() { + let c = chars[i]; + if c == '-' && chars.get(i + 1) == Some(&'-') { + break; + } + if c == '/' && chars.get(i + 1) == Some(&'*') { + i += 2; + while i < chars.len() && !(chars[i] == '*' && chars.get(i + 1) == Some(&'/')) { + i += 1; + } + i += 2; + continue; + } + if matches!(c, '\'' | '"' | '`' | '[') { + i = copy_quote(&chars, i, &mut buf); + started = true; + continue; + } + if (c.is_alphanumeric() || c == '_') && !started && c.is_alphabetic() { + // Detect the leading keyword to decide block vs simple. + let mut j = i; + while j < chars.len() && (chars[j].is_alphanumeric() || chars[j] == '_') { + j += 1; + } + let w: String = chars[i..j].iter().collect::().to_ascii_uppercase(); + if w == "DECLARE" || w == "BEGIN" { + in_block = true; + anon = true; + } + started = true; + // fall through to copy chars normally below + } + // Once inside a CREATE statement, promote to block on a body keyword. 
+ if c.is_alphabetic() { + let mut j = i; + while j < chars.len() && (chars[j].is_alphanumeric() || chars[j] == '_') { + j += 1; + } + let w: String = chars[i..j].iter().collect::().to_ascii_uppercase(); + if matches!( + w.as_str(), + "PROCEDURE" | "FUNCTION" | "PACKAGE" | "TRIGGER" | "TYPE" + ) && buf.to_ascii_uppercase().trim_start().starts_with("CREATE") + { + in_block = true; + } + buf.push_str(&chars[i..j].iter().collect::()); + i = j; + continue; + } + if c == ';' && !in_block { + let s = normalize(&buf); + if !s.is_empty() { + normal.push(s); + } + buf.clear(); + started = false; + i += 1; + continue; + } + buf.push(c); + i += 1; + } + buf.push(' '); + } + let s = normalize(&buf); + if !s.is_empty() { + if anon { + special.push(s); + normal.extend(harvest_dml(&buf)); + } else { + normal.push(s); + } + } + (normal, special) +} + +fn sql_files(dir: &Path) -> Vec { + let mut out = Vec::new(); + let mut stack = vec![dir.to_path_buf()]; + while let Some(d) = stack.pop() { + let Ok(entries) = fs::read_dir(&d) else { + continue; + }; + for e in entries.flatten() { + let p = e.path(); + if p.is_dir() { + stack.push(p); + } else if p.extension().is_some_and(|x| x == "sql") { + out.push(p); + } + } + } + out.sort(); + out +} + +/// Load the lines of `datasets//` into `seen` (for cross-file dedup). 
+fn seed_seen(seen: &mut HashSet, rel: &str) { + if let Ok(c) = fs::read_to_string(Path::new("datasets").join(rel)) { + for l in c.lines() { + if !l.trim().is_empty() { + seen.insert(l.trim().to_string()); + } + } + } +} + +fn build_spark(src: &Path) { + let mut seen = HashSet::new(); + seed_seen(&mut seen, "spark_sql/clickbench_spark.txt"); + seed_seen(&mut seen, "spark_sql/databricks_perf.txt"); + let mut kept = Vec::new(); + let mut total = 0usize; + for f in sql_files(src) { + for s in split_spark(&fs::read_to_string(&f).unwrap_or_default()) { + total += 1; + if seen.insert(s.clone()) { + kept.push(s); + } + } + } + fs::write( + "datasets/spark_sql/spark_sql_tst.txt", + format!("{}\n", kept.join("\n")), + ) + .expect("write spark corpus"); + println!( + "spark_sql: {total} parsed, {} kept. wrote datasets/spark_sql/spark_sql_tst.txt", + kept.len() + ); +} + +fn build_oracle(src: &Path) { + let mut seen = HashSet::new(); + seed_seen(&mut seen, "oracle/oracle_examples.txt"); + let mut normal_kept = Vec::new(); + let mut special_seen = HashSet::new(); + let mut special_kept = Vec::new(); + let (mut n_total, mut s_total) = (0usize, 0usize); + for f in sql_files(src) { + let (normal, special) = split_oracle(&fs::read_to_string(&f).unwrap_or_default()); + for s in normal { + n_total += 1; + if seen.insert(s.clone()) { + normal_kept.push(s); + } + } + for s in special { + s_total += 1; + if special_seen.insert(s.clone()) { + special_kept.push(s); + } + } + } + fs::write( + "datasets/oracle/oracle_schemas.txt", + format!("{}\n", normal_kept.join("\n")), + ) + .expect("write oracle corpus"); + // Special PL/SQL blocks live outside any dialect directory, so the + // per-statement benchmark never loads them (they would be huge outliers); they + // are kept once as whole-block test cases. 
+ fs::create_dir_all("datasets/special").expect("create datasets/special"); + fs::write( + "datasets/special/oracle_plsql_blocks.txt", + format!("{}\n", special_kept.join("\n")), + ) + .expect("write oracle blocks"); + println!( + "oracle: {n_total} normal parsed, {} kept; {s_total} blocks, {} special kept (datasets/special/oracle_plsql_blocks.txt)", + normal_kept.len(), + special_kept.len(), + ); +} + +fn main() { + if let Some(s) = std::env::args().nth(1) { + build_spark(Path::new(&s)); + } + if let Some(o) = std::env::args().nth(2) { + build_oracle(Path::new(&o)); + } +} + +#[cfg(test)] +mod tests { + use super::{split_oracle, split_spark}; + + #[test] + fn spark_region_is_one_statement() { + let sql = "SELECT 1;\n--QUERY-DELIMITER-START\nBEGIN\n DECLARE x INT;\n SET x = 1;\nEND;\n--QUERY-DELIMITER-END\nSELECT 2;"; + assert_eq!( + split_spark(sql), + vec![ + "SELECT 1".to_string(), + "BEGIN DECLARE x INT; SET x = 1; END;".to_string(), + "SELECT 2".to_string(), + ] + ); + } + + #[test] + fn spark_strips_directives_and_comments() { + let sql = "--CONFIG dim\n--SET spark.x=1\nSELECT 1 -- trailing\n;\nSET spark.y = 2;"; + assert_eq!( + split_spark(sql), + vec!["SELECT 1".to_string(), "SET spark.y = 2".to_string()] + ); + } + + #[test] + fn oracle_anon_block_is_special_and_inner_dml_harvested() { + // The whole anonymous block is kept once as a special entry, and its inner + // INSERTs also become individual normal corpus statements (in order). 
+ let sql = "INSERT INTO t VALUES (1);\nBEGIN\n INSERT INTO t VALUES (2);\n INSERT INTO t VALUES (3);\nEND;\n/\nINSERT INTO t VALUES (4);"; + let (normal, special) = split_oracle(sql); + assert_eq!( + normal, + vec![ + "INSERT INTO t VALUES (1)".to_string(), + "INSERT INTO t VALUES (2)".to_string(), + "INSERT INTO t VALUES (3)".to_string(), + "INSERT INTO t VALUES (4)".to_string(), + ] + ); + assert_eq!( + special, + vec!["BEGIN INSERT INTO t VALUES (2); INSERT INTO t VALUES (3); END;".to_string()] + ); + } + + #[test] + fn oracle_declare_block_keeps_inner_semicolons() { + // Declarations and assignments are not DML, so only the INSERT is harvested. + let sql = "DECLARE v NUMBER;\nBEGIN\n v := 1;\n INSERT INTO t VALUES (v);\nEND;\n/"; + let (normal, special) = split_oracle(sql); + assert_eq!(normal, vec!["INSERT INTO t VALUES (v)".to_string()]); + assert_eq!( + special, + vec!["DECLARE v NUMBER; BEGIN v := 1; INSERT INTO t VALUES (v); END;".to_string()] + ); + } + + #[test] + fn oracle_plain_statements_split_on_semicolon() { + let sql = "CREATE TABLE t (a NUMBER);\nINSERT INTO t VALUES (1);"; + let (normal, special) = split_oracle(sql); + assert_eq!( + normal, + vec![ + "CREATE TABLE t (a NUMBER)".to_string(), + "INSERT INTO t VALUES (1)".to_string(), + ] + ); + assert!(special.is_empty()); + } + + #[test] + fn oracle_create_procedure_block_stays_whole_in_normal() { + // A CREATE PROCEDURE block is real DDL: kept whole, in normal, not special. + let sql = "CREATE PROCEDURE p IS\nBEGIN\n INSERT INTO t VALUES (1);\nEND;\n/"; + let (normal, special) = split_oracle(sql); + assert_eq!( + normal, + vec!["CREATE PROCEDURE p IS BEGIN INSERT INTO t VALUES (1); END;".to_string()] + ); + assert!(special.is_empty()); + } +} diff --git a/src/bin/build_sqlite_suite.rs b/src/bin/build_sqlite_suite.rs new file mode 100644 index 0000000..d2d3764 --- /dev/null +++ b/src/bin/build_sqlite_suite.rs @@ -0,0 +1,354 @@ +//! 
Rebuild `datasets/sqlite/sqlite_official_suite.txt` from the original SQLite +//! official test suite, with a SQLite-aware statement splitter that keeps +//! compound `CREATE TRIGGER ... BEGIN ...; ... END` statements intact. +//! +//! The corpus is one statement per line. The original extractor (removed from the +//! repo) split on every `;`, which shredded trigger bodies on their inner +//! semicolons and produced invalid fragments (issue #22). This rebuilds the suite +//! correctly: it splits only on top-level `;` (outside string/identifier quotes, +//! comments, `BEGIN ... END` trigger bodies, and `CASE ... END`), normalizes each +//! statement to one line, strips comments, and dedupes within the suite and +//! against the other committed SQLite corpus files. +//! +//! Source: the SQLite project's own tests, public domain, as bundled in +//! codeschool/sqlite-parser under `test/sql/official-suite/*.sql`. Clone that repo +//! and pass the directory: +//! +//! git clone --depth 1 https://github.com/codeschool/sqlite-parser /tmp/sp +//! cargo run --release --bin build_sqlite_suite -- /tmp/sp/test/sql/official-suite +//! +//! Then repack (`tar --zstd -cf datasets.tar.zst datasets`) and re-run the SQLite +//! oracle (`cargo run --release -p oracle -- sqlite`). + +#![allow( + clippy::doc_markdown, + clippy::too_many_lines, + clippy::items_after_statements +)] + +use std::collections::HashSet; +use std::fs; +use std::path::Path; + +/// Split raw SQLite script text into normalized one-line statements. +/// +/// Splits on top-level `;` only: semicolons inside single/double/backtick/bracket +/// quotes, `--` and block comments, a `CREATE TRIGGER` `BEGIN ... END` body, or a +/// `CASE ... END` are not statement terminators. Each statement is normalized to a +/// single line (whitespace runs collapsed) with comments removed. 
+#[must_use] +fn split_sql(input: &str) -> Vec { + let mut out = Vec::new(); + let mut buf = String::new(); + let mut word = String::new(); + let mut case_depth = 0usize; + let mut block_depth = 0usize; + let mut is_trigger = false; + + // Apply a completed word's effect on block/case tracking. + fn classify( + word: &mut String, + case_depth: &mut usize, + block_depth: &mut usize, + is_trigger: &mut bool, + ) { + if word.is_empty() { + return; + } + match word.to_ascii_uppercase().as_str() { + "TRIGGER" => *is_trigger = true, + "CASE" => *case_depth += 1, + "END" => { + if *case_depth > 0 { + *case_depth -= 1; + } else if *block_depth > 0 { + *block_depth -= 1; + } + } + // The only BEGIN inside a CREATE TRIGGER is the body opener. A bare + // BEGIN (transaction) is not a trigger, so it does not open a block. + "BEGIN" if *is_trigger => *block_depth += 1, + _ => {} + } + word.clear(); + } + + // Push a single normalizing space (collapse runs, skip leading). + fn push_space(buf: &mut String) { + if !buf.is_empty() && !buf.ends_with(' ') { + buf.push(' '); + } + } + + let end_statement = |buf: &mut String, + out: &mut Vec, + case_depth: &mut usize, + block_depth: &mut usize, + is_trigger: &mut bool| { + let s = buf.trim().to_string(); + if !s.is_empty() { + // Final pass: collapse any whitespace that survived inside quoted + // literals so the statement is one line (string contents do not + // affect parse benchmarking). + let normalized = s.split_whitespace().collect::>().join(" "); + out.push(normalized); + } + buf.clear(); + *case_depth = 0; + *block_depth = 0; + *is_trigger = false; + }; + + let chars: Vec = input.chars().collect(); + let mut i = 0; + while i < chars.len() { + let c = chars[i]; + + // Comments: strip to a single space. 
+ if c == '-' && chars.get(i + 1) == Some(&'-') { + classify( + &mut word, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + while i < chars.len() && chars[i] != '\n' { + i += 1; + } + push_space(&mut buf); + continue; + } + if c == '/' && chars.get(i + 1) == Some(&'*') { + classify( + &mut word, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + i += 2; + while i < chars.len() && !(chars[i] == '*' && chars.get(i + 1) == Some(&'/')) { + i += 1; + } + i += 2; + push_space(&mut buf); + continue; + } + + // Quoted string / identifier: copy verbatim, honoring doubling escapes. + if matches!(c, '\'' | '"' | '`' | '[') { + classify( + &mut word, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + let close = if c == '[' { ']' } else { c }; + buf.push(c); + i += 1; + loop { + if i >= chars.len() { + break; + } + let d = chars[i]; + if d == close { + // Doubling escape ('' "" ``) keeps the quote open. Brackets + // have no escape in SQLite. + if close != ']' && chars.get(i + 1) == Some(&close) { + buf.push(d); + buf.push(d); + i += 2; + continue; + } + buf.push(d); + i += 1; + break; + } + buf.push(d); + i += 1; + } + continue; + } + + if c.is_alphanumeric() || c == '_' { + word.push(c); + buf.push(c); + i += 1; + continue; + } + + // Non-word character: settle the pending word first. 
+ classify( + &mut word, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + + if c == ';' && case_depth == 0 && block_depth == 0 { + end_statement( + &mut buf, + &mut out, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + i += 1; + continue; + } + + if c.is_whitespace() { + push_space(&mut buf); + } else { + buf.push(c); + } + i += 1; + } + classify( + &mut word, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + end_statement( + &mut buf, + &mut out, + &mut case_depth, + &mut block_depth, + &mut is_trigger, + ); + out +} + +fn main() { + let src = std::env::args().nth(1).unwrap_or_else(|| { + eprintln!("usage: build_sqlite_suite "); + std::process::exit(2); + }); + let src = Path::new(&src); + + // Statements already in the other committed SQLite corpus files, to dedupe + // against (keep the suite from duplicating Spider / sql-create-context). + let mut seen: HashSet = HashSet::new(); + for other in ["spider_sqlite.txt", "sql_create_ctx.txt"] { + let p = Path::new("datasets/sqlite").join(other); + if let Ok(content) = fs::read_to_string(&p) { + for line in content.lines() { + let l = line.trim(); + if !l.is_empty() { + seen.insert(l.to_string()); + } + } + } + } + let existing = seen.len(); + + let mut files: Vec<_> = fs::read_dir(src) + .expect("read official-suite dir") + .filter_map(Result::ok) + .map(|e| e.path()) + .filter(|p| p.extension().is_some_and(|x| x == "sql")) + .collect(); + files.sort(); + + let mut out_lines: Vec = Vec::new(); + let mut total = 0usize; + for f in &files { + let content = fs::read_to_string(f).expect("read sql file"); + for stmt in split_sql(&content) { + total += 1; + if seen.insert(stmt.clone()) { + out_lines.push(stmt); + } + } + } + + let dest = Path::new("datasets/sqlite/sqlite_official_suite.txt"); + fs::write(dest, format!("{}\n", out_lines.join("\n"))).expect("write suite"); + println!( + "{} source files, {total} statements parsed, {} kept after dedup ({} were dupes of the 
existing {existing} SQLite statements or each other).", + files.len(), + out_lines.len(), + total - out_lines.len(), + ); + println!("wrote {}", dest.display()); +} + +#[cfg(test)] +mod tests { + use super::split_sql; + + #[test] + fn keeps_trigger_body_intact() { + let sql = "CREATE TRIGGER r1 AFTER INSERT ON t2 BEGIN\n SELECT 'hello';\nEND;\nSELECT 1;"; + assert_eq!( + split_sql(sql), + vec![ + "CREATE TRIGGER r1 AFTER INSERT ON t2 BEGIN SELECT 'hello'; END".to_string(), + "SELECT 1".to_string(), + ] + ); + } + + #[test] + fn multi_statement_trigger_body_stays_one_statement() { + let sql = "CREATE TRIGGER t AFTER UPDATE ON x BEGIN UPDATE a SET b=1; DELETE FROM c; END; DROP TABLE x;"; + assert_eq!( + split_sql(sql), + vec![ + "CREATE TRIGGER t AFTER UPDATE ON x BEGIN UPDATE a SET b=1; DELETE FROM c; END" + .to_string(), + "DROP TABLE x".to_string(), + ] + ); + } + + #[test] + fn leading_semicolons_and_newlines() { + // The suite often puts the terminator at the start of the next line. + let sql = "CREATE TABLE abc(a, b, c)\n;ALTER TABLE abc ADD d INTEGER\n;SELECT 1\n"; + assert_eq!( + split_sql(sql), + vec![ + "CREATE TABLE abc(a, b, c)".to_string(), + "ALTER TABLE abc ADD d INTEGER".to_string(), + "SELECT 1".to_string(), + ] + ); + } + + #[test] + fn semicolons_in_strings_and_comments_do_not_split() { + let sql = "SELECT ';' AS x -- ; not a split\n; SELECT /* ; */ 2;"; + assert_eq!( + split_sql(sql), + vec!["SELECT ';' AS x".to_string(), "SELECT 2".to_string()] + ); + } + + #[test] + fn case_end_does_not_close_a_trigger() { + let sql = + "CREATE TRIGGER t AFTER INSERT ON x BEGIN SELECT CASE WHEN 1 THEN 2 ELSE 3 END; END; SELECT 9;"; + assert_eq!( + split_sql(sql), + vec![ + "CREATE TRIGGER t AFTER INSERT ON x BEGIN SELECT CASE WHEN 1 THEN 2 ELSE 3 END; END" + .to_string(), + "SELECT 9".to_string(), + ] + ); + } + + #[test] + fn bare_begin_transaction_is_its_own_statement() { + let sql = "BEGIN; INSERT INTO t VALUES(1); COMMIT;"; + assert_eq!( + split_sql(sql), + 
vec![ + "BEGIN".to_string(), + "INSERT INTO t VALUES(1)".to_string(), + "COMMIT".to_string(), + ] + ); + } +} diff --git a/src/bin/repair_corpus.rs b/src/bin/repair_corpus.rs new file mode 100644 index 0000000..e0b889e --- /dev/null +++ b/src/bin/repair_corpus.rs @@ -0,0 +1,387 @@ +//! Clean residual corpus artifacts left by the original `;`-only extractor in the +//! corpus files that have no upstream reconstruction tool (issue #22, the long +//! tail). The reconstructed SQLite/Spark/Oracle suites are rebuilt by +//! `build_sqlite_suite` / `build_proc_suites`; this pass repairs the rest in +//! place on the unpacked `datasets/`. +//! +//! Two transforms, both conservative (they never invent SQL and only ever drop a +//! line that cannot be a valid standalone statement): +//! +//! 1. T-SQL `GO` batch separators. `GO` is a sqlcmd/SSMS client directive, not +//! T-SQL grammar. The extractor split on `;`, so `GO` lines with no semicolon +//! were glued onto the next statement (`GO SELECT ...`) or sat between two +//! statements on one line (`... GO ...`). The real SQL Server oracle accepts +//! `GO `, so every parser that correctly rejects `GO` was charged a +//! false recall failure. We split each line on top-level `GO` tokens, +//! recovering the real statements. Applied to the `tsql` corpus and the mixed +//! `multi` corpus (which also carries T-SQL GO batches). +//! +//! 2. Pure procedural fragments (all dialects). Lines that are only a block +//! keyword (`END IF`, `END LOOP`, `END TRY`, `BEGIN CATCH`, ...) or that start +//! with a clause keyword that can never begin a statement (`ELSE`, `ELSIF`, +//! `WHEN`, `THEN`, `AND`, `OR`, `LOOP`) are body pieces of a split +//! `CREATE FUNCTION`/`PROCEDURE`/batch and are dropped. Bare `END`/`END;` is +//! kept for SQLite, where `END` is a COMMIT synonym, but dropped elsewhere. +//! `DELIMITER` client directives (any dialect) and MySQL/`multi` `//` +//! routine-delimiter fragments are dropped too. 
The prefix rule also catches +//! the string-literal fragments that begin mid-prose (`And then my heart ...`, +//! `loop will exit ...`). +//! +//! A general multi-line string-literal repair was considered and rejected: the +//! corpus mixes `''`-doubling and backslash-escaping dialects (plus PG `E'...'` +//! and dollar-quoting), so a quote scanner mislabels valid statements wholesale. +//! The few genuine string fragments that remain are mostly provenance-only noise. +//! +//! Run `--apply` to write; otherwise it is a dry run reporting counts and samples. +//! After applying, repack `datasets.tar.zst` and re-run the T-SQL oracle (the +//! `GO` split produces new statement strings that need fresh labels). + +#![allow( + clippy::doc_markdown, + clippy::too_many_lines, + clippy::items_after_statements +)] + +use std::collections::HashSet; +use std::fs; +use std::path::Path; + +use sql_ast_benchmark::datasets::{ensure_corpus, Dialect}; + +/// Split a T-SQL line on top-level `GO` batch separators (a `GO` token bounded by +/// whitespace or line edges, outside any quote). Returns the recovered statement +/// pieces (callers normalize/drop empties). Quote tracking (`'...'` with `''` +/// escaping, `"..."`, `[...]`) keeps a `GO` inside a literal from splitting; T-SQL +/// does not use backslash escapes, so `''`/`""` doubling is the only escape. +fn split_go(line: &str) -> Vec { + let chars: Vec = line.chars().collect(); + let mut pieces = Vec::new(); + let mut buf = String::new(); + let mut i = 0; + while i < chars.len() { + let c = chars[i]; + // Consume a quoted literal verbatim so a `GO` inside it is not a split. 
+ if matches!(c, '\'' | '"' | '[') { + let close = if c == '[' { ']' } else { c }; + buf.push(c); + i += 1; + while i < chars.len() { + let d = chars[i]; + if d == close { + if close != ']' && chars.get(i + 1) == Some(&close) { + buf.push(d); + buf.push(d); + i += 2; + continue; + } + buf.push(d); + i += 1; + break; + } + buf.push(d); + i += 1; + } + continue; + } + // A `GO` token: preceded by start-or-space, followed by space-or-end. + let at_boundary = i == 0 || chars[i - 1].is_whitespace(); + if at_boundary + && (c == 'G' || c == 'g') + && matches!(chars.get(i + 1), Some('O' | 'o')) + && chars.get(i + 2).is_none_or(|n| n.is_whitespace()) + { + pieces.push(std::mem::take(&mut buf)); + i += 2; + continue; + } + buf.push(c); + i += 1; + } + pieces.push(buf); + pieces +} + +fn normalize(s: &str) -> String { + s.split_whitespace().collect::>().join(" ") +} + +/// Whether a balanced line is a pure procedural fragment that should be dropped. +fn is_procedural_fragment(line: &str, dialect: Dialect) -> bool { + let up = normalize(line).to_ascii_uppercase(); + let bare = up.trim_end_matches(';').trim_end(); + + // `END` and `END ` (END IF / END LOOP / END CASE / END