diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index 8d247ff8f..ca8b69c55 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -23,6 +23,7 @@ pub mod notify_buffer; pub mod prepared_statements; pub mod pub_sub; pub mod query; +pub mod rewrite; pub mod route_query; pub mod set; pub mod show_shards; @@ -33,7 +34,7 @@ pub mod unknown_command; #[cfg(test)] mod testing; -use self::query::ExplainResponseState; +use self::{query::ExplainResponseState, rewrite::RewriteDriver}; pub use context::QueryEngineContext; use notify_buffer::NotifyBuffer; pub use two_pc::phase::TwoPcPhase; @@ -53,6 +54,7 @@ pub struct QueryEngine { two_pc: TwoPc, notify_buffer: NotifyBuffer, pending_explain: Option, + rewrite_driver: RewriteDriver, hooks: QueryEngineHooks, } @@ -136,7 +138,7 @@ impl QueryEngine { self.pending_explain = None; - let command = self.router.command(); + let command = self.router.command().clone(); let mut route = if let Some(ref route) = self.set_route { route.clone() } else { @@ -153,75 +155,72 @@ impl QueryEngine { context.client_request.route = Some(route.clone()); match command { - Command::Shards(shards) => self.show_shards(context, *shards).await?, + Command::Shards(shards) => self.show_shards(context, shards).await?, Command::StartTransaction { query, transaction_type, extended, } => { - self.start_transaction(context, query.clone(), *transaction_type, *extended) + self.start_transaction(context, query.clone(), transaction_type, extended) .await? } Command::CommitTransaction { extended } => { self.set_route = None; - if self.backend.connected() || *extended { - let extended = *extended; + if self.backend.connected() || extended { let transaction_route = self.transaction_route(&route)?; context.client_request.route = Some(transaction_route.clone()); context.cross_shard_disabled = Some(false); self.end_connected(context, &transaction_route, false, extended) .await?; } else { - self.end_not_connected(context, false, *extended).await? + self.end_not_connected(context, false, extended).await? } } Command::RollbackTransaction { extended } => { self.set_route = None; - if self.backend.connected() || *extended { - let extended = *extended; + if self.backend.connected() || extended { let transaction_route = self.transaction_route(&route)?; context.client_request.route = Some(transaction_route.clone()); context.cross_shard_disabled = Some(false); self.end_connected(context, &transaction_route, true, extended) .await?; } else { - self.end_not_connected(context, true, *extended).await? + self.end_not_connected(context, true, extended).await? } } Command::Query(_) => self.execute(context, &route).await?, - Command::Listen { channel, shard } => { - self.listen(context, &channel.clone(), shard.clone()) - .await? - } + Command::Listen { channel, shard } => self.listen(context, &channel, shard).await?, Command::Notify { channel, payload, shard, - } => { - self.notify(context, &channel.clone(), &payload.clone(), &shard.clone()) - .await? - } - Command::Unlisten(channel) => self.unlisten(context, &channel.clone()).await?, + } => self.notify(context, &channel, &payload, &shard).await?, + Command::Unlisten(channel) => self.unlisten(context, &channel).await?, Command::Set { name, value } => { if self.backend.connected() { self.execute(context, &route).await? } else { - self.set(context, name.clone(), value.clone()).await? + self.set(context, name, value).await? } } Command::SetRoute(route) => { - self.set_route(context, route.clone()).await?; + self.set_route(context, route).await?; } Command::Copy(_) => self.execute(context, &route).await?, Command::Rewrite(query) => { - context.client_request.rewrite(query)?; + context.client_request.rewrite(&query)?; self.execute(context, &route).await?; } + Command::PlannedRewrite(plan) => { + if let Some(handler) = self.rewrite_driver.handler(plan.kind()) { + handler(self, context, &plan)?; + } + } Command::Deallocate => self.deallocate(context).await?, - Command::Discard { extended } => self.discard(context, *extended).await?, - command => self.unknown_command(context, command.clone()).await?, + Command::Discard { extended } => self.discard(context, extended).await?, + command => self.unknown_command(context, command).await?, } self.hooks.after_execution(context)?; diff --git a/pgdog/src/frontend/client/query_engine/rewrite.rs b/pgdog/src/frontend/client/query_engine/rewrite.rs new file mode 100644 index 000000000..af23a7622 --- /dev/null +++ b/pgdog/src/frontend/client/query_engine/rewrite.rs @@ -0,0 +1,29 @@ +use std::collections::HashMap; + +use crate::frontend::router::rewrite::{RewriteExecutionKind, RewriteExecutionPlan}; + +use super::{Error, QueryEngine, QueryEngineContext}; + +type RewriteHandler = + fn(&mut QueryEngine, &mut QueryEngineContext<'_>, &RewriteExecutionPlan) -> Result<(), Error>; + +#[derive(Debug, Default)] +pub struct RewriteDriver { + handlers: HashMap, +} + +impl RewriteDriver { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + } + } + + pub fn register(&mut self, kind: RewriteExecutionKind, handler: RewriteHandler) { + self.handlers.insert(kind, handler); + } + + pub fn handler(&self, kind: &RewriteExecutionKind) -> Option { + self.handlers.get(kind).copied() + } +} diff --git a/pgdog/src/frontend/router/mod.rs b/pgdog/src/frontend/router/mod.rs index 888c95164..b245f7f50 100644 --- a/pgdog/src/frontend/router/mod.rs +++ b/pgdog/src/frontend/router/mod.rs @@ -4,6 +4,7 @@ pub mod context; pub mod copy; pub mod error; pub mod parser; +pub mod rewrite; pub mod round_robin; pub mod search_path; pub mod sharding; @@ -13,6 +14,7 @@ pub use error::Error; use lazy_static::lazy_static; use parser::Shard; pub use parser::{Command, QueryParser, Route}; +pub use rewrite::{RewriteExecutionKind, RewriteExecutionPlan, RewriteRegistry}; use super::ClientRequest; pub use context::RouterContext; diff --git a/pgdog/src/frontend/router/parser/command.rs b/pgdog/src/frontend/router/parser/command.rs index 038d930e2..fe66c841b 100644 --- a/pgdog/src/frontend/router/parser/command.rs +++ b/pgdog/src/frontend/router/parser/command.rs @@ -1,5 +1,6 @@ use super::*; use crate::{ + frontend::router::rewrite::RewriteExecutionPlan, frontend::{client::TransactionType, BufferedQuery}, net::parameter::ParameterValue, }; @@ -43,6 +44,7 @@ pub enum Command { }, Unlisten(String), SetRoute(Route), + PlannedRewrite(RewriteExecutionPlan), } impl Command { @@ -53,6 +55,7 @@ impl Command { match self { Self::Query(route) => route, + Self::PlannedRewrite(plan) => plan.route(), _ => &DEFAULT_ROUTE, } } diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index eb3633531..41b45745c 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -8,6 +8,7 @@ use crate::{ router::{ context::RouterContext, parser::{rewrite::Rewrite, OrderBy, Shard}, + rewrite::{PlannerContext, PlannerInput, RewritePlanner, RewriteRegistry}, round_robin, sharding::{Centroids, ContextBuilder, Value as ShardingValue}, }, @@ -67,6 +68,7 @@ pub struct QueryParser { // Plugin read override. plugin_output: PluginOutput, explain_recorder: Option, + rewrite_registry: RewriteRegistry, } impl Default for QueryParser { @@ -77,11 +79,16 @@ impl Default for QueryParser { shard: Shard::All, plugin_output: PluginOutput::default(), explain_recorder: None, + rewrite_registry: RewriteRegistry::new(), } } } impl QueryParser { + pub fn register_rewrite_planner(&mut self, planner: Box) { + self.rewrite_registry.add_planner(planner); + } + fn recorder_mut(&mut self) -> Option<&mut ExplainRecorder> { self.explain_recorder.as_mut() } @@ -132,11 +139,26 @@ impl QueryParser { Command::default() }; + if let Some(plan) = self.rewrite_registry.plan( + &PlannerContext::new(&qp_context), + &PlannerInput::Command(&command), + )? { + command = Command::PlannedRewrite(plan); + } + // If the cluster only has one shard, use direct-to-shard queries. - if let Command::Query(ref mut query) = command { - if !matches!(query.shard(), Shard::Direct(_)) && qp_context.shards == 1 { - query.set_shard_mut(0); + match &mut command { + Command::Query(ref mut query) => { + if !matches!(query.shard(), Shard::Direct(_)) && qp_context.shards == 1 { + query.set_shard_mut(0); + } + } + Command::PlannedRewrite(ref mut plan) => { + if !matches!(plan.route().shard(), Shard::Direct(_)) && qp_context.shards == 1 { + plan.route_mut().set_shard_mut(0); + } } + _ => {} } Ok(command) diff --git a/pgdog/src/frontend/router/rewrite/mod.rs b/pgdog/src/frontend/router/rewrite/mod.rs new file mode 100644 index 000000000..cde2f09e5 --- /dev/null +++ b/pgdog/src/frontend/router/rewrite/mod.rs @@ -0,0 +1,105 @@ +use crate::frontend::router::parser::{Command, Error as ParserError, QueryParserContext, Route}; + +/// Enum describing supported rewrite plans. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum RewriteExecutionKind { + /// Placeholder variant for future rewrites. + Unspecified, +} + +/// Planner output describing the rewrite to execute. +#[derive(Debug, Clone)] +pub struct RewriteExecutionPlan { + kind: RewriteExecutionKind, + route: Route, +} + +impl RewriteExecutionPlan { + pub fn new(kind: RewriteExecutionKind, route: Route) -> Self { + Self { kind, route } + } + + pub fn kind(&self) -> &RewriteExecutionKind { + &self.kind + } + + pub fn route(&self) -> &Route { + &self.route + } + + pub fn route_mut(&mut self) -> &mut Route { + &mut self.route + } +} + +/// Input provided to rewrite planners. +#[derive(Debug)] +pub enum PlannerInput<'a> { + /// Full router command before rewrite handling. + Command(&'a Command), +} + +/// Shared planner context exposing parser metadata. +pub struct PlannerContext<'a, 'b> { + context: &'a QueryParserContext<'b>, +} + +impl<'a, 'b> PlannerContext<'a, 'b> { + pub fn new(context: &'a QueryParserContext<'b>) -> Self { + Self { context } + } + + pub fn parser_context(&self) -> &'a QueryParserContext<'b> { + self.context + } +} + +pub trait RewritePlanner { + fn plan( + &self, + _planner_context: &PlannerContext<'_, '_>, + _input: &PlannerInput<'_>, + ) -> Result, ParserError>; +} + +pub struct RewriteRegistry { + planners: Vec>, +} + +impl RewriteRegistry { + pub fn new() -> Self { + Self { + planners: Vec::new(), + } + } + + pub fn add_planner(&mut self, planner: Box) { + self.planners.push(planner); + } + + pub fn plan( + &self, + context: &PlannerContext<'_, '_>, + input: &PlannerInput<'_>, + ) -> Result, ParserError> { + for planner in &self.planners { + if let Some(plan) = planner.plan(context, input)? { + return Ok(Some(plan)); + } + } + + Ok(None) + } +} + +impl Default for RewriteRegistry { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for RewriteRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RewriteRegistry").finish() + } +}