Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions sqlparser_bench/benches/sqlparser_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use criterion::{criterion_group, criterion_main, Criterion};
use sqlparser::dialect::GenericDialect;
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect};
use sqlparser::keywords::Keyword;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::{Span, Word};
Expand Down Expand Up @@ -177,11 +177,58 @@ fn parse_compound_chain(c: &mut Criterion) {
group.finish();
}

/// Benchmarks parsing of pathological `if(<keyword-fn>(<keyword-fn>(...x`
/// chains, which previously caused 2^N work in `parse_prefix`.
///
/// Each nested `current_time(` segment used to be explored twice at every
/// level (once via the speculative reserved-word arm, once via the
/// unreserved-word fallback), doubling the work per level. Post-fix the cost
/// is linear in chain length.
///
/// One benchmark is registered per chain length (`chain_10`, `chain_20`,
/// `chain_30`) under the `parse_prefix_keyword_call_chain` group, using the
/// PostgreSQL dialect.
fn parse_prefix_keyword_call_chain(c: &mut Criterion) {
    let mut group = c.benchmark_group("parse_prefix_keyword_call_chain");
    let dialect = PostgreSqlDialect {};

    for n in [10usize, 20, 30] {
        // `if(` followed by `n` unterminated `current_time(` calls and a
        // trailing `x`; the parentheses are intentionally never closed.
        let sql = format!("if({}x", "current_time(".repeat(n));

        group.bench_function(format!("chain_{n}"), |b| {
            // Return the parse result instead of discarding it so Criterion
            // feeds it through its own black box and the optimizer cannot
            // treat the parse as dead code.
            b.iter(|| Parser::parse_sql(&dialect, std::hint::black_box(&sql)));
        });
    }

    group.finish();
}

/// Benchmarks parsing of pathological `case-case-case-...c` chains, which
/// previously caused 2^N work in `parse_prefix`.
///
/// Each `case` token used to trigger a speculative `parse_case_expr` that
/// recursively descends the chain, but the unreserved-word fallback returns
/// `Identifier(case)`, so the overall `parse_prefix` succeeds and the failure
/// cache never fires. Post-fix the per-arm cache short-circuits the
/// speculative descent.
///
/// One benchmark is registered per chain length (`chain_10`, `chain_20`,
/// `chain_30`) under the `parse_prefix_case_chain` group, using the SQLite
/// dialect.
fn parse_prefix_case_chain(c: &mut Criterion) {
    let mut group = c.benchmark_group("parse_prefix_case_chain");
    let dialect = SQLiteDialect {};

    for n in [10usize, 20, 30] {
        // `n` repetitions of `case<TAB>-` terminated by a single `c`.
        let sql = format!("{}c", "case\t-".repeat(n));

        group.bench_function(format!("chain_{n}"), |b| {
            // Return the parse result instead of discarding it so Criterion
            // feeds it through its own black box and the optimizer cannot
            // treat the parse as dead code.
            b.iter(|| Parser::parse_sql(&dialect, std::hint::black_box(&sql)));
        });
    }

    group.finish();
}

criterion_group!(
benches,
basic_queries,
word_to_ident,
parse_many_identifiers,
parse_compound_chain
parse_compound_chain,
parse_prefix_keyword_call_chain,
parse_prefix_case_chain
);
criterion_main!(benches);
49 changes: 48 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#[cfg(not(feature = "std"))]
use alloc::{
boxed::Box,
collections::BTreeMap,
format,
string::{String, ToString},
vec,
Expand All @@ -24,6 +25,9 @@ use core::{
fmt::{self, Display},
str::FromStr,
};
#[cfg(feature = "std")]
use std::collections::BTreeMap;

use helpers::attached_token::AttachedToken;

use log::debug;
Expand Down Expand Up @@ -359,6 +363,12 @@ pub struct Parser<'a> {
options: ParserOptions,
/// Ensures the stack does not overflow by limiting recursion depth.
recursion_counter: RecursionCounter,
/// Cached errors from `parse_prefix` calls that returned `Err`. See
/// [`Parser::parse_prefix`] for the 2^N patterns this guards.
failed_prefix_positions: BTreeMap<usize, ParserError>,
/// Cached errors from the speculative reserved-word prefix arm. See
/// [`Parser::parse_prefix`] for the 2^N patterns this guards.
failed_reserved_word_prefix_positions: BTreeMap<usize, ParserError>,
}

impl<'a> Parser<'a> {
Expand All @@ -385,6 +395,8 @@ impl<'a> Parser<'a> {
dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
failed_prefix_positions: BTreeMap::new(),
failed_reserved_word_prefix_positions: BTreeMap::new(),
}
}

Expand Down Expand Up @@ -446,6 +458,8 @@ impl<'a> Parser<'a> {
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithSpan>) -> Self {
self.tokens = tokens;
self.index = 0;
self.failed_prefix_positions.clear();
self.failed_reserved_word_prefix_positions.clear();
self
}

Expand Down Expand Up @@ -1717,6 +1731,23 @@ impl<'a> Parser<'a> {
return prefix;
}

// Memoize parse_prefix failures to break 2^N speculation when both
// prefix arms fail at every level (e.g. `IF(current_time(...x`).
// The per-arm cache in `parse_prefix_inner` complements this for
// chains where the reserved arm fails but the unreserved fallback
// succeeds (e.g. `case-case-...c`).
let start_index = self.index;
if let Some(cached) = self.failed_prefix_positions.get(&start_index) {
return Err(cached.clone());
}
let result = self.parse_prefix_inner();
if let Err(ref e) = result {
self.failed_prefix_positions.insert(start_index, e.clone());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering about memory usage: how large do we expect the caches to get in the worst-case scenarios?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a ~2GB limit on the fuzzer processes, and after many billions of iterations it did not manage to OOM — it hit the ASAN thread-count limit before it ever hit an OOM — so at the very least I believe memory growth is not exponential. I can try to get scaling measurements, though.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yeah, I'm mostly worried about the increased memory usage. A lot of deployments don't have e.g. 2GB to allocate to parsing a SQL query. I think the main issue is that the map contains strings; if what we stored were a much cheaper (ideally `Copy`) object, it might be feasible.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Got a scaling for the cache memory. The two caches are keyed by token index, so each entry is about 40 bytes (8 byte usize key plus a 32 byte ParserError). Panels plot entries, heap String bytes, and total memory vs chain length N for valid SQL, the nested if(current_time(...x chain, and the wide case-...c chain.

  • On valid SQL both maps stay empty (0 entries at a 27 KiB SELECT).
  • Heap String bytes are bounded by recursion depth, not input: past the recursion limit the cached error is RecursionLimitExceeded with no String, so strings peak at ~1.8 KiB and never grow with input.
  • Total memory is linear with a small constant: capped at ~4 KiB for deep nesting, at most ~13x input for the adversarial wide shape (159 KiB at a 12 KiB input). Reaching 2 GB would need a ~154 MB single statement.

So it is linear and never exponential, and valid SQL costs nothing.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

valid SQL costs nothing

My understanding is that, e.g. on the `parse_expr_prefix_by_reserved_word` path, each unique attempt will add an error string to the map; if so, we're potentially looking at roughly one entry in that map per word. Is that the case?

To be clear, it's not about a particular amount of memory; the main thing is that we don't significantly increase the memory usage of the parser. If the additional memory usage grows as a function of the SQL string, that is already problematic; to improve it, we would want each entry to be as small as possible, or potentially consider other solutions.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I follow. Currently the runtime is exponential, so the parser simply fails on these inputs. With the fix, the time requirement becomes linear, and the memory requirement remains linear in the input size even for pathological cases.

}
result
}

fn parse_prefix_inner(&mut self) -> Result<Expr, ParserError> {
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
// string literal represents a literal of that type. Some examples:
//
Expand Down Expand Up @@ -1801,7 +1832,21 @@ impl<'a> Parser<'a> {
// We first try to parse the word and following tokens as a special expression, and if that fails,
// we rollback and try to parse it as an identifier.
let w = w.clone();
match self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w, span)) {
// Memoize failed speculative reserved-word parses. When
// the reserved arm (CASE, CURRENT_TIME, etc.) does
// exponential work but the unreserved fallback ultimately
// succeeds, the overall `parse_prefix` returns `Ok` and the
// outer cache never fires. Chains like `case-case-...c`
// need this per-arm cache to break the doubling.
Comment on lines +1835 to +1840
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to follow this comment, what is 'outer cache' and 'break doubling' referring to?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "outer cache" is the cache consulted in the parent call before the recursive call, and "doubling" refers to the fact that, before this change, the number of operations doubled with each additional layer of depth (hence the previous 2^layers complexity).

let try_parse_result = if let Some(cached) = self
.failed_reserved_word_prefix_positions
.get(&next_token_index)
{
Err(cached.clone())
} else {
self.try_parse(|parser| parser.parse_expr_prefix_by_reserved_word(&w, span))
};
match try_parse_result {
// This word indicated an expression prefix and parsing was successful
Ok(Some(expr)) => Ok(expr),

Expand All @@ -1815,6 +1860,8 @@ impl<'a> Parser<'a> {
// we rollback and return the parsing error we got from trying to parse a
// special expression (to maintain backwards compatibility of parsing errors).
Err(e) => {
self.failed_reserved_word_prefix_positions
.insert(next_token_index, e.clone());
if !self.dialect.is_reserved_for_identifier(w.keyword) {
if let Ok(Some(expr)) = self.maybe_parse(|parser| {
parser.parse_expr_prefix_by_unreserved_word(&w, span)
Expand Down
46 changes: 46 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19004,3 +19004,49 @@ fn parse_compound_chain_no_exponential_blowup() {
rx.recv_timeout(Duration::from_secs(5))
.expect("parser should reject this quickly, not loop exponentially");
}

/// Guards against the 2^N parse-time blowup in `parse_prefix` for inputs of
/// the form `IF(current_time(current_time(...x`.
///
/// Every nested `current_time(` used to be explored twice per level (first
/// through the speculative reserved-word arm, then again through the
/// unreserved-word fallback), so the work doubled with each level. Post-fix
/// the failing parse short-circuits via the position-keyed cache.
#[test]
fn parse_prefix_keyword_call_chain_no_exponential_blowup() {
    // `if(` + 30 unterminated `current_time(` calls + `x`.
    let sql = format!("if({}x", "current_time(".repeat(30));

    // Parse on a helper thread and signal completion over a channel, so the
    // test can bound how long it waits for the parser to return.
    let (done_tx, done_rx) = std::sync::mpsc::channel();
    std::thread::spawn(move || {
        let _ = Parser::parse_sql(&PostgreSqlDialect {}, &sql);
        let _ = done_tx.send(());
    });

    done_rx
        .recv_timeout(std::time::Duration::from_secs(5))
        .expect("parser should reject this quickly, not loop exponentially");
}

/// Guards against the 2^N parse-time blowup in `parse_prefix` for inputs of
/// the form `case-case-case-...c`.
///
/// Each `case` token triggers a speculative `parse_case_expr` that fails,
/// while the unreserved-word fallback returns `Identifier(case)`; because the
/// overall prefix parse succeeds, the outer failure cache never fires.
/// Post-fix the per-arm cache short-circuits the speculative descent.
#[test]
fn parse_prefix_case_chain_no_exponential_blowup() {
    // 30 repetitions of `case<TAB>-` followed by a terminating `c`.
    let mut sql = "case\t-".repeat(30);
    sql.push('c');

    // Parse on a helper thread and signal completion over a channel, so the
    // test can bound how long it waits for the parser to return.
    let (done_tx, done_rx) = std::sync::mpsc::channel();
    std::thread::spawn(move || {
        let _ = Parser::parse_sql(&SQLiteDialect {}, &sql);
        let _ = done_tx.send(());
    });

    done_rx
        .recv_timeout(std::time::Duration::from_secs(5))
        .expect("parser should reject this quickly, not loop exponentially");
}
Loading