diff --git a/pgdog/src/frontend/router/parser/insert.rs b/pgdog/src/frontend/router/parser/insert.rs deleted file mode 100644 index 99bb61d3..00000000 --- a/pgdog/src/frontend/router/parser/insert.rs +++ /dev/null @@ -1,362 +0,0 @@ -//! Handle INSERT statements. -use pg_query::{protobuf::*, NodeEnum}; -use tracing::debug; - -use crate::{ - backend::ShardingSchema, - frontend::router::{ - round_robin, - sharding::{ContextBuilder, Tables, Value as ShardingValue}, - }, - net::Bind, -}; - -use super::{Column, Error, Shard, Table, Tuple, Value}; - -/// Parse an `INSERT` statement. -#[derive(Debug)] -pub struct Insert<'a> { - stmt: &'a InsertStmt, -} - -impl<'a> Insert<'a> { - /// Parse an `INSERT` statement. - pub fn new(stmt: &'a InsertStmt) -> Self { - Self { stmt } - } - - /// Get columns, if any are specified. - pub fn columns(&'a self) -> Vec> { - self.stmt - .cols - .iter() - .map(Column::try_from) - .collect::>, Error>>() - .ok() - .unwrap_or(vec![]) - } - - /// Get table name, if specified (should always be). - pub fn table(&'a self) -> Option> { - self.stmt.relation.as_ref().map(Table::from) - } - - /// Get rows from the statement. - pub fn tuples(&'a self) -> Vec> { - if let Some(select) = &self.stmt.select_stmt { - if let Some(NodeEnum::SelectStmt(stmt)) = &select.node { - let tuples = stmt - .values_lists - .iter() - .map(Tuple::try_from) - .collect::>, ()>>(); - return tuples.unwrap_or(vec![]); - } - } - - vec![] - } - - /// Calculate the number of tuples in the statement. - pub fn num_tuples(&self) -> usize { - if let Some(select) = &self.stmt.select_stmt { - if let Some(NodeEnum::SelectStmt(stmt)) = &select.node { - return stmt.values_lists.len(); - } - } - - 0 - } - - /// Get the sharding key for the statement. - pub fn shard( - &'a self, - schema: &'a ShardingSchema, - bind: Option<&Bind>, - ) -> Result { - let tables = Tables::new(schema); - let columns = self.columns(); - let table = self.table(); - let tuples = self.tuples(); - - let key = table.and_then(|table| tables.key(table, &columns)); - - if let Some(table) = table { - // Schema-based routing. - if let Some(schema) = schema.schemas.get(table.schema()) { - return Ok(schema.shard().into()); - } - } - - if self.num_tuples() != 1 { - debug!("multiple tuples in an INSERT statement"); - return Ok(Shard::All); - } - - if let Some(key) = key { - if let Some(bind) = bind { - if let Ok(Some(param)) = bind.parameter(key.position) { - if param.is_null() { - return Ok(Shard::All); - } else { - // Arrays not supported as sharding keys at the moment. - let value = ShardingValue::from_param(¶m, key.table.data_type)?; - let ctx = ContextBuilder::new(key.table) - .value(value) - .shards(schema.shards) - .build()?; - return Ok(ctx.apply()?); - } - } - } - - if let Some(value) = tuples.first().and_then(|tuple| tuple.get(key.position)) { - match value { - Value::Integer(int) => { - let ctx = ContextBuilder::new(key.table) - .data(*int) - .shards(schema.shards) - .build()?; - return Ok(ctx.apply()?); - } - - Value::String(str) => { - let ctx = ContextBuilder::new(key.table) - .data(*str) - .shards(schema.shards) - .build()?; - return Ok(ctx.apply()?); - } - - _ => (), - } - } - } else if let Some(table) = table { - // If this table is sharded, but the sharding key isn't in the query, - // choose a shard at random. - if tables.sharded(table).is_some() { - return Ok(Shard::Direct(round_robin::next() % schema.shards)); - } - } - - Ok(Shard::All) - } -} - -#[cfg(test)] -mod test { - use pg_query::{parse, NodeEnum}; - - use crate::backend::ShardedTables; - use crate::config::ShardedTable; - use crate::net::bind::Parameter; - use crate::net::Format; - use bytes::Bytes; - - use super::super::Value; - use super::*; - - #[test] - fn test_insert() { - let query = parse("INSERT INTO my_table (id, email) VALUES (1, 'test@test.com')").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - assert_eq!( - insert.table(), - Some(Table { - name: "my_table", - schema: None, - alias: None, - }) - ); - assert_eq!( - insert.columns(), - vec![ - Column { - name: "id", - ..Default::default() - }, - Column { - name: "email", - ..Default::default() - } - ] - ); - } - - _ => panic!("not an insert"), - } - } - - #[test] - fn test_insert_params() { - let query = parse("INSERT INTO my_table (id, email) VALUES ($1, $2)").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - assert_eq!( - insert.tuples(), - vec![Tuple { - values: vec![Value::Placeholder(1), Value::Placeholder(2),] - }] - ) - } - - _ => panic!("not an insert"), - } - } - - #[test] - fn test_insert_typecasts() { - let query = - parse("INSERT INTO sharded (id, value) VALUES ($1::INTEGER, $2::VARCHAR)").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - assert_eq!( - insert.tuples(), - vec![Tuple { - values: vec![Value::Placeholder(1), Value::Placeholder(2),] - }] - ) - } - - _ => panic!("not an insert"), - } - } - - #[test] - fn test_shard_insert() { - let query = parse("INSERT INTO sharded (id, value) VALUES (1, 'test')").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - let schema = ShardingSchema { - shards: 3, - tables: ShardedTables::new( - vec![ - ShardedTable { - name: Some("sharded".into()), - column: "id".into(), - ..Default::default() - }, - ShardedTable { - name: None, - column: "user_id".into(), - ..Default::default() - }, - ], - vec![], - false, - ), - ..Default::default() - }; - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - let shard = insert.shard(&schema, None).unwrap(); - assert!(matches!(shard, Shard::Direct(2))); - - let bind = Bind::new_params( - "", - &[Parameter { - len: 1, - data: "3".as_bytes().into(), - }], - ); - - let shard = insert.shard(&schema, Some(&bind)).unwrap(); - assert!(matches!(shard, Shard::Direct(1))); - - let bind = Bind::new_params_codes( - "", - &[Parameter { - len: 8, - data: Bytes::copy_from_slice(&234_i64.to_be_bytes()), - }], - &[Format::Binary], - ); - - let shard = insert.shard(&schema, Some(&bind)).unwrap(); - assert!(matches!(shard, Shard::Direct(0))); - } - - _ => panic!("not an insert"), - } - - let query = parse("INSERT INTO orders (user_id, value) VALUES (1, 'test')").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - let shard = insert.shard(&schema, None).unwrap(); - assert!(matches!(shard, Shard::Direct(2))); - } - - _ => panic!("not a select"), - } - - let query = parse("INSERT INTO random_table (users_id, value) VALUES (1, 'test')").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - let shard = insert.shard(&schema, None).unwrap(); - assert!(matches!(shard, Shard::All)); - } - - _ => panic!("not a select"), - } - - // Round robin test. - let query = parse("INSERT INTO sharded (value) VALUES ('test')").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - let shard = insert.shard(&schema, None).unwrap(); - assert!(matches!(shard, Shard::Direct(_))); - } - - _ => panic!("not a select"), - } - } - - #[test] - fn test_null_sharding_key_routes_to_all() { - let query = parse("INSERT INTO sharded (id, value) VALUES ($1, 'test')").unwrap(); - let select = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); - let schema = ShardingSchema { - shards: 3, - tables: ShardedTables::new( - vec![ShardedTable { - name: Some("sharded".into()), - column: "id".into(), - ..Default::default() - }], - vec![], - false, - ), - ..Default::default() - }; - - match &select.node { - Some(NodeEnum::InsertStmt(stmt)) => { - let insert = Insert::new(stmt); - let bind = Bind::new_params("", &[Parameter::new_null()]); - let shard = insert.shard(&schema, Some(&bind)).unwrap(); - assert!(matches!(shard, Shard::All)); - } - _ => panic!("not an insert"), - } - } -} diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index 8fad8c44..4ffdbe98 100644 --- a/pgdog/src/frontend/router/parser/mod.rs +++ b/pgdog/src/frontend/router/parser/mod.rs @@ -15,7 +15,6 @@ pub mod explain_trace; mod expression; pub mod from_clause; pub mod function; -pub mod insert; pub mod key; pub mod limit; pub mod multi_tenant; @@ -46,7 +45,6 @@ pub use distinct::{Distinct, DistinctBy, DistinctColumn}; pub use error::Error; pub use function::Function; pub use function::{FunctionBehavior, LockingBehavior}; -pub use insert::Insert; pub use key::Key; pub use limit::{Limit, LimitClause}; pub use order_by::OrderBy; diff --git a/pgdog/src/frontend/router/parser/query/delete.rs b/pgdog/src/frontend/router/parser/query/delete.rs index f6163130..3bc61c80 100644 --- a/pgdog/src/frontend/router/parser/query/delete.rs +++ b/pgdog/src/frontend/router/parser/query/delete.rs @@ -1,4 +1,3 @@ -use super::StatementParser; use super::*; impl QueryParser { diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index e953b4ae..79ecfab3 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -507,10 +507,17 @@ impl QueryParser { stmt: &InsertStmt, context: &mut QueryParserContext, ) -> Result { - let insert = Insert::new(stmt); - context.shards_calculator.push(ShardWithPriority::new_table( - insert.shard(&context.sharding_schema, context.router_context.bind)?, - )); + let mut parser = StatementParser::from_insert( + stmt, + context.router_context.bind, + &context.sharding_schema, + self.recorder_mut(), + ); + let shard = parser.shard()?.unwrap_or(Shard::All); + + context + .shards_calculator + .push(ShardWithPriority::new_table(shard.clone())); let shard = context.shards_calculator.shard(); if let Some(recorder) = self.recorder_mut() { diff --git a/pgdog/src/frontend/router/parser/statement.rs b/pgdog/src/frontend/router/parser/statement.rs index 49ab3959..58ddc3ce 100644 --- a/pgdog/src/frontend/router/parser/statement.rs +++ b/pgdog/src/frontend/router/parser/statement.rs @@ -13,9 +13,11 @@ use super::{ }; use crate::{ backend::{Schema, ShardingSchema}, + config::ShardedTable, frontend::router::{ parser::Shard, - sharding::{ContextBuilder, SchemaSharder}, + round_robin, + sharding::{ContextBuilder, SchemaSharder, Tables}, }, net::{parameter::ParameterValue, Bind}, }; @@ -594,12 +596,31 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { } } + /// Find sharded table config for a column. + /// Named configs (with explicit table names) match specific table+column. + /// Column-only configs match any table with that column name. + fn get_sharded_table(&self, column: Column<'a>) -> Option<&ShardedTable> { + // Try named table configs first + if let Some(sharded_table) = self.schema.tables().get_table(column) { + if sharded_table.name.is_some() { + return Some(sharded_table); + } + } + + // Column-only config: user explicitly wants any table with this column to be sharded + self.schema + .tables + .tables() + .iter() + .find(|t| t.name.is_none() && t.column == column.name) + } + fn compute_shard( &mut self, column: Column<'a>, value: Value<'a>, ) -> Result, Error> { - if let Some(table) = self.schema.tables().get_table(column) { + if let Some(table) = self.get_sharded_table(column) { let context = ContextBuilder::new(table); let shard = match value { Value::Placeholder(pos) => { @@ -614,6 +635,10 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { } else { return Ok(None); }; + // NULL sharding key broadcasts to all shards + if param.is_null() { + return Ok(Some(Shard::All)); + } let value = ShardingValue::from_param(¶m, table.data_type)?; Some( context @@ -639,7 +664,7 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { .build()? .apply()?, ), - Value::Null => None, + Value::Null => return Ok(Some(Shard::All)), _ => None, }; @@ -755,7 +780,7 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { // parse array literals or parameters, so route to all shards. if is_any && matches!(values, SearchResult::Value(_)) - && self.schema.tables().get_table(column).is_some() + && self.get_sharded_table(column).is_some() { return Ok(SearchResult::Match(Shard::All)); } @@ -1012,6 +1037,69 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { stmt: &'a InsertStmt, ctx: &SearchContext<'a>, ) -> Result, Error> { + // Schema-based routing takes priority for INSERTs + if let Some(table) = ctx.table { + if let Some(schema) = self.schema.schemas.get(table.schema()) { + return Ok(SearchResult::Match(schema.shard().into())); + } + } + + // Get the column names from INSERT INTO table (col1, col2, ...) + let columns: Vec<&str> = stmt + .cols + .iter() + .filter_map(|node| match &node.node { + Some(NodeEnum::ResTarget(target)) => Some(target.name.as_str()), + _ => None, + }) + .collect(); + + // Handle different INSERT forms + if let Some(ref select_node) = stmt.select_stmt { + if let Some(NodeEnum::SelectStmt(ref select_stmt)) = select_node.node { + // Multi-row VALUES broadcasts to all shards + if select_stmt.values_lists.len() > 1 { + return Ok(SearchResult::Match(Shard::All)); + } + + // INSERT...SELECT (no VALUES): try to extract sharding key from target list + if select_stmt.values_lists.is_empty() { + // Try to extract constants from SELECT target list + if !select_stmt.target_list.is_empty() { + for (pos, target_node) in select_stmt.target_list.iter().enumerate() { + if let Some(NodeEnum::ResTarget(ref target)) = target_node.node { + if let Some(column_name) = columns.get(pos) { + let column = Column { + name: column_name, + table: ctx.table.map(|t| t.name), + schema: ctx.table.and_then(|t| t.schema), + }; + + if self.get_sharded_table(column).is_some() { + if let Some(ref val) = target.val { + if let Ok(value) = Value::try_from(val.as_ref()) { + if let Some(shard) = + self.compute_shard_with_ctx(column, value, ctx)? + { + return Ok(SearchResult::Match(shard)); + } + } + } + } + } + } + } + } + + // INSERT...SELECT without extractable key broadcasts + return Ok(SearchResult::Match(Shard::All)); + } + } + } else { + // No select_stmt (DEFAULT VALUES) broadcasts to all shards + return Ok(SearchResult::Match(Shard::All)); + } + // Handle CTEs (WITH clause) if let Some(ref with_clause) = stmt.with_clause { for cte in &with_clause.ctes { @@ -1049,7 +1137,7 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { schema: ctx.table.and_then(|t| t.schema), }; - if self.schema.tables().get_table(column).is_some() { + if self.get_sharded_table(column).is_some() { // Try to extract the value directly if let Ok(value) = Value::try_from(value_node) { if let Some(shard) = @@ -1070,12 +1158,17 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { } } } + } + } - // Handle INSERT ... SELECT by recursively searching the SelectStmt - let result = self.select_search(select_node, ctx)?; - if !result.is_none() { - return Ok(result); - } + // Round-robin fallback: if table is sharded but no sharding key found, + // pick a shard at random + if let Some(table) = ctx.table { + let tables = Tables::new(self.schema); + if tables.sharded(table).is_some() { + return Ok(SearchResult::Match(Shard::Direct( + round_robin::next() % self.schema.shards, + ))); } } @@ -1841,11 +1934,84 @@ mod test { } #[test] - fn test_insert_no_sharding_key_returns_none() { + fn test_insert_no_sharding_key_uses_round_robin() { + // When sharding key is missing but table is sharded, use round-robin let result = run_test("INSERT INTO sharded (name) VALUES ('foo')", None); + assert!(matches!(result.unwrap(), Some(Shard::Direct(_)))); + } + + #[test] + fn test_insert_multi_row_broadcasts() { + // Multi-row INSERTs should broadcast to all shards + let result = run_test( + "INSERT INTO sharded (id, name) VALUES (1, 'foo'), (2, 'bar')", + None, + ); + assert_eq!(result.unwrap(), Some(Shard::All)); + } + + #[test] + fn test_insert_multi_row_with_params_broadcasts() { + // Multi-row INSERTs with params should also broadcast + let bind = Bind::new_params( + "", + &[ + Parameter::new(b"1"), + Parameter::new(b"foo"), + Parameter::new(b"2"), + Parameter::new(b"bar"), + ], + ); + let result = run_test( + "INSERT INTO sharded (id, name) VALUES ($1, $2), ($3, $4)", + Some(&bind), + ); + assert_eq!(result.unwrap(), Some(Shard::All)); + } + + #[test] + fn test_insert_unsharded_table_returns_none() { + // Unsharded table should return None (not round-robin) + let result = run_test("INSERT INTO unsharded_table (name) VALUES ('foo')", None); assert!(result.unwrap().is_none()); } + #[test] + fn test_insert_select_with_constant() { + // INSERT ... SELECT where the sharding key is a constant in the SELECT target list + let result = run_test("INSERT INTO sharded (id, name) SELECT 1, 'test'", None); + assert!(matches!(result.unwrap(), Some(Shard::Direct(_)))); + } + + #[test] + fn test_insert_select_with_constant_param() { + // INSERT ... SELECT where the sharding key is a parameter in the SELECT target list + let bind = Bind::new_params("", &[Parameter::new(b"1")]); + let result = run_test( + "INSERT INTO sharded (id, name) SELECT $1, 'test'", + Some(&bind), + ); + assert!(matches!(result.unwrap(), Some(Shard::Direct(_)))); + } + + #[test] + fn test_insert_null_sharding_key_param_broadcasts() { + // NULL sharding key as param should broadcast to all shards + let bind = Bind::new_params("", &[Parameter::new_null(), Parameter::new(b"test")]); + let result = run_test( + "INSERT INTO sharded (id, name) VALUES ($1, $2)", + Some(&bind), + ); + assert_eq!(result.unwrap(), Some(Shard::All)); + } + + #[test] + fn test_insert_null_sharding_key_literal_broadcasts() { + // NULL sharding key as literal should broadcast to all shards + let result = run_test("INSERT INTO sharded (id, name) VALUES (NULL, 'test')", None); + assert_eq!(result.unwrap(), Some(Shard::All)); + } + // Schema-based sharding fallback tests use crate::backend::replication::ShardedSchemas; use pgdog_config::sharding::ShardedSchema; @@ -1976,4 +2142,114 @@ mod test { assert_eq!(result1, Some(Shard::Direct(1))); assert_eq!(result2, Some(Shard::Direct(2))); } + + // Column-only sharded table detection tests (using loaded schema) + + fn run_test_column_only(stmt: &str, bind: Option<&Bind>) -> Result, Error> { + // Use column-only sharded table config (no table name) + let schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "tenant_id".into(), + // No table name - column-only config + ..Default::default() + }], + vec![], + false, + ), + ..Default::default() + }; + let raw = pg_query::parse(stmt) + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + let mut parser = StatementParser::from_raw(&raw, bind, &schema, None)?; + parser.shard() + } + + #[test] + fn test_column_only_select() { + let result = run_test_column_only("SELECT * FROM users WHERE tenant_id = 1", None).unwrap(); + assert!(result.is_some(), "Should detect column-only sharding key"); + } + + #[test] + fn test_column_only_select_with_alias() { + let result = + run_test_column_only("SELECT * FROM users u WHERE u.tenant_id = 1", None).unwrap(); + assert!( + result.is_some(), + "Should detect column-only sharding key with alias" + ); + } + + #[test] + fn test_column_only_select_bound_param() { + let bind = Bind::new_params("", &[Parameter::new(b"42")]); + let result = + run_test_column_only("SELECT * FROM users WHERE tenant_id = $1", Some(&bind)).unwrap(); + assert!( + result.is_some(), + "Should detect column-only sharding key with bound param" + ); + } + + #[test] + fn test_column_only_update() { + let result = + run_test_column_only("UPDATE users SET name = 'foo' WHERE tenant_id = 1", None) + .unwrap(); + assert!( + result.is_some(), + "Should detect column-only sharding key in UPDATE" + ); + } + + #[test] + fn test_column_only_delete() { + let result = run_test_column_only("DELETE FROM users WHERE tenant_id = 1", None).unwrap(); + assert!( + result.is_some(), + "Should detect column-only sharding key in DELETE" + ); + } + + #[test] + fn test_column_only_insert() { + let result = run_test_column_only( + "INSERT INTO users (tenant_id, name) VALUES (1, 'foo')", + None, + ) + .unwrap(); + assert!( + result.is_some(), + "Should detect column-only sharding key in INSERT" + ); + } + + #[test] + fn test_column_only_any_table() { + // Column-only configs work with any table + let result = + run_test_column_only("SELECT * FROM unknown_table WHERE tenant_id = 1", None).unwrap(); + assert!( + result.is_some(), + "Column-only config should work with any table" + ); + } + + #[test] + fn test_column_only_wrong_column() { + // Column-only config shouldn't match different column name + let result = run_test_column_only("SELECT * FROM users WHERE other_id = 1", None).unwrap(); + assert!( + result.is_none(), + "Column-only config should not match different column, got {:?}", + result + ); + } }