From 6cc3ac03bb288c4f2dfccec274008720f259f4b4 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 18 Sep 2024 01:55:59 -0700 Subject: [PATCH] refactor: add origin information to `Column` --- sqlx-core/src/column.rs | 49 ++++++++++++++++++++ sqlx-postgres/src/column.rs | 7 +++ sqlx-postgres/src/connection/describe.rs | 59 ++++++++++++++++++++++++ sqlx-postgres/src/connection/mod.rs | 8 ++++ src/lib.rs | 1 + 5 files changed, 124 insertions(+) diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 9f45819ed6..4cf0166e1e 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -2,6 +2,7 @@ use crate::database::Database; use crate::error::Error; use std::fmt::Debug; +use std::sync::Arc; pub trait Column: 'static + Send + Sync + Debug { type Database: Database; @@ -20,6 +21,54 @@ pub trait Column: 'static + Send + Sync + Debug { /// Gets the type information for the column. fn type_info(&self) -> &::TypeInfo; + + /// If this column comes from a table, return the table and original column name. + /// + /// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression + /// or else the source table could not be determined. + /// + /// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information, + /// or has not overridden this method. + fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown } +} + +/// A [`Column`] that originates from a table. +#[derive(Debug, Clone)] +pub struct TableColumn { + /// The name of the table (optionally schema-qualified) that the column comes from. + pub table: Arc, + /// The original name of the column. + pub name: Arc, +} + +/// The possible statuses for our knowledge of the origin of a [`Column`]. +#[derive(Debug, Clone)] +pub enum ColumnOrigin { + /// The column is known to originate from a table. + /// + /// Included is the table name and original column name. + Table(TableColumn), + /// The column originates from an expression, or else its origin could not be determined. + Expression, + /// The database driver does not know the column origin at this time. + /// + /// This may happen if: + /// * The connection is in the middle of executing a query, + /// and cannot query the catalog to fetch this information. + /// * The connection does not have access to the database catalog. + /// * The implementation of [`Column`] did not override [`Column::origin()`]. + Unknown, +} + +impl ColumnOrigin { + /// Returns the true column origin, if known. + pub fn table_column(&self) -> Option<&TableColumn> { + if let Self::Table(table_column) = self { + Some(table_column) + } else { + None + } + } } /// A type that can be used to index into a [`Row`] or [`Statement`]. diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index a838c27b75..fdc5793a78 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -2,6 +2,7 @@ use crate::ext::ustr::UStr; use crate::{PgTypeInfo, Postgres}; pub(crate) use sqlx_core::column::{Column, ColumnIndex}; +use sqlx_core::column::ColumnOrigin; #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] @@ -13,6 +14,8 @@ pub struct PgColumn { pub(crate) relation_id: Option, #[cfg_attr(feature = "offline", serde(skip))] pub(crate) relation_attribute_no: Option, + #[cfg_attr(feature = "offline", serde(skip))] + pub(crate) origin: ColumnOrigin, } impl PgColumn { @@ -51,4 +54,8 @@ impl Column for PgColumn { fn type_info(&self) -> &PgTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 9a46a202d5..e05b6fccf4 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,3 +1,4 @@ +use std::collections::btree_map; use crate::error::Error; use crate::ext::ustr::UStr; use crate::io::StatementId; @@ -14,6 +15,9 @@ use futures_core::future::BoxFuture; use smallvec::SmallVec; use sqlx_core::query_builder::QueryBuilder; use std::sync::Arc; +use sqlx_core::column::{ColumnOrigin, TableColumn}; +use sqlx_core::hash_map; +use crate::connection::TableColumns; /// Describes the type of the `pg_type.typtype` column /// @@ -122,6 +126,12 @@ impl PgConnection { let type_info = self .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) .await?; + + let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) { + self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await? + } else { + ColumnOrigin::Expression + }; let column = PgColumn { ordinal: index, @@ -129,6 +139,7 @@ impl PgConnection { type_info, relation_id: field.relation_id, relation_attribute_no: field.relation_attribute_no, + origin, }; columns.push(column); @@ -188,6 +199,54 @@ impl PgConnection { Ok(PgTypeInfo(PgType::DeclareWithOid(oid))) } } + + async fn maybe_fetch_column_origin( + &mut self, + relation_id: Oid, + attribute_no: i16, + should_fetch: bool, + ) -> Result { + let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) { + hash_map::Entry::Occupied(table_columns) => { + table_columns.into_mut() + }, + hash_map::Entry::Vacant(vacant) => { + if !should_fetch { return Ok(ColumnOrigin::Unknown); } + + let table_name: String = query_scalar("SELECT $1::oid::regclass::text") + .bind(relation_id) + .fetch_one(&mut *self) + .await?; + + vacant.insert(TableColumns { + table_name: table_name.into(), + columns: Default::default(), + }) + } + }; + + let column_name = match table_columns.columns.entry(attribute_no) { + btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()), + btree_map::Entry::Vacant(vacant) => { + if !should_fetch { return Ok(ColumnOrigin::Unknown); } + + let column_name: String = query_scalar( + "SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2" + ) + .bind(relation_id) + .bind(attribute_no) + .fetch_one(&mut *self) + .await?; + + Arc::clone(vacant.insert(column_name.into())) + } + }; + + Ok(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: column_name + })) + } fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result> { Box::pin(async move { diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 5a6a597ead..0244b8eecf 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -57,6 +58,7 @@ pub struct PgConnection { cache_type_info: HashMap, cache_type_oid: HashMap, cache_elem_type_to_array: HashMap, + cache_table_to_column_names: HashMap, // number of ReadyForQuery messages that we are currently expecting pub(crate) pending_ready_for_query_count: usize, @@ -68,6 +70,12 @@ pub struct PgConnection { log_settings: LogSettings, } +pub(crate) struct TableColumns { + table_name: Arc, + /// Attribute number -> name. + columns: BTreeMap>, +} + impl PgConnection { /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { diff --git a/src/lib.rs b/src/lib.rs index e2fd0b1567..6615972a3a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; +pub use sqlx_core::column::ColumnOrigin; pub use sqlx_core::connection::{ConnectOptions, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe;