diff --git a/postgresql_proxy/interceptors.py b/postgresql_proxy/interceptors.py index 60dc937..d4fefcd 100644 --- a/postgresql_proxy/interceptors.py +++ b/postgresql_proxy/interceptors.py @@ -55,14 +55,20 @@ def intercept(self, packet_type, data): # Query, ends with b'\x00' data = self._intercept_query(data, ic_queries) elif packet_type == b"P": - # Statement that needs parsing. - # First byte of the body is some Statement flag. Ignore, don't lose - # Next is the query, same as above, ends with an b'\x00' - # Last 2 bytes are the number of parameters. Ignore, don't lose - statement = data[0:1] - query = self._intercept_query(data[1:-2], ic_queries) - params = data[-2:] - data = statement + query + params + # Parse packet body: + # statement_name\x00 + query\x00 + int16(param_count) + uint32[] + # Keep the binary suffix untouched (count + OID array). + statement_end = data.find(b"\x00") + if statement_end != -1: + query_start = statement_end + 1 + query_end = data.find(b"\x00", query_start) + if query_end != -1: + statement = data[:query_start] + query = self._intercept_query( + data[query_start : query_end + 1], ic_queries + ) + params = data[query_end + 1 :] + data = statement + query + params if packet_type == b"": # Connection request / context. Ignore the first 4 bytes, keep it diff --git a/requirements-test.txt b/requirements-test.txt index a0e00db..9f3b999 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,2 +1,3 @@ pytest==9.0.3 pytest-timeout==2.4.0 +psycopg[binary]==3.3.4 diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 19175c1..be47559 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -8,6 +8,7 @@ import threading import time +import psycopg import psycopg2 import pytest @@ -332,3 +333,72 @@ def test_psql_ssl_file_batch_stress_no_hang(postgres_settings, ssl_proxy_port): "psql -f batch succeeded but expected marker missing " f"(run={run_idx + 1}, {elapsed=:.2f}s) stdout_tail={out_tail}" ) + + +def test_extended_query_protocol_parse_packet_with_high_oid_params_passes_through_proxy( + postgres_settings, plain_proxy_port +): + """Regression: proxy must not corrupt Extended Query Protocol Parse packets. + + psycopg v3 sends Parse → Bind → Execute for parameterized queries. The Parse body + ends with binary uint32 OIDs; jsonb OID 3802 (0x00000EDA) contains 0xDA which is + not valid UTF-8. The old interceptor sliced the body incorrectly and crashed on + decode, causing the connection to hang or drop. + """ + with psycopg.connect( + host="127.0.0.1", + port=plain_proxy_port, + user=postgres_settings["user"], + password=postgres_settings["password"], + dbname=postgres_settings["dbname"], + sslmode="disable", + ) as conn: + with conn.cursor() as cur: + cur.execute( + "DROP TABLE IF EXISTS _test_jsonb_proxy_params;" + "CREATE TABLE _test_jsonb_proxy_params " + "(id serial PRIMARY KEY, data jsonb, label text);" + ) + + cur.execute( + "INSERT INTO _test_jsonb_proxy_params (data, label) " + "VALUES (%s, %s) RETURNING id", + (psycopg.types.json.Jsonb({"key": "value"}), "hello"), + ) + row = cur.fetchone() + + assert row is not None and row[0] >= 1 + + +def test_extended_query_protocol_named_prepared_statement_passes_through_proxy( + postgres_settings, plain_proxy_port +): + """Parse packets with a non-empty statement name must also be relayed correctly. + + The statement_name field precedes the query text in the Parse body. The fix uses + find(b'\\x00') to locate boundaries, so named statements work the same as anonymous + ones (empty name). + """ + with psycopg.connect( + host="127.0.0.1", + port=plain_proxy_port, + user=postgres_settings["user"], + password=postgres_settings["password"], + dbname=postgres_settings["dbname"], + sslmode="disable", + # Prepare after the first execution of the same query (i.e. on 2nd run). + prepare_threshold=1, + ) as conn: + with conn.cursor() as cur: + # Execute twice so psycopg can promote the query to a named statement. + for val in (1, 2): + cur.execute("SELECT %s::int + 1", (val,)) + result = cur.fetchone() + assert result == (val + 1,) + + # Verify psycopg created a named prepared statement in this session. + cur.execute( + "SELECT count(*) FROM pg_prepared_statements WHERE name LIKE '_pg3_%'" + ) + prepared_count = cur.fetchone() + assert prepared_count is not None and prepared_count[0] >= 1