Skip to content

Commit

Permalink
refactor: add origin information to Column
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Sep 18, 2024
1 parent 2692f3c commit 6cc3ac0
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 0 deletions.
49 changes: 49 additions & 0 deletions sqlx-core/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::database::Database;
use crate::error::Error;

use std::fmt::Debug;
use std::sync::Arc;

pub trait Column: 'static + Send + Sync + Debug {
type Database: Database<Column = Self>;
Expand All @@ -20,6 +21,54 @@ pub trait Column: 'static + Send + Sync + Debug {

/// Gets the type information for the column.
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;

/// If this column comes from a table, return the table and original column name.
///
/// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression
/// or else the source table could not be determined.
///
/// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information,
/// or has not overridden this method.
fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown }
}

/// A [`Column`] that originates from a table.
#[derive(Debug, Clone)]
pub struct TableColumn {
/// The name of the table (optionally schema-qualified) that the column comes from.
pub table: Arc<str>,
/// The original name of the column.
pub name: Arc<str>,
}

/// The possible statuses for our knowledge of the origin of a [`Column`].
#[derive(Debug, Clone)]
pub enum ColumnOrigin {
/// The column is known to originate from a table.
///
/// Included is the table name and original column name.
Table(TableColumn),
/// The column originates from an expression, or else its origin could not be determined.
Expression,
/// The database driver does not know the column origin at this time.
///
/// This may happen if:
/// * The connection is in the middle of executing a query,
/// and cannot query the catalog to fetch this information.
/// * The connection does not have access to the database catalog.
/// * The implementation of [`Column`] did not override [`Column::origin()`].
Unknown,
}

impl ColumnOrigin {
/// Returns the true column origin, if known.
pub fn table_column(&self) -> Option<&TableColumn> {
if let Self::Table(table_column) = self {
Some(table_column)
} else {
None
}
}
}

/// A type that can be used to index into a [`Row`] or [`Statement`].
Expand Down
7 changes: 7 additions & 0 deletions sqlx-postgres/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::ext::ustr::UStr;
use crate::{PgTypeInfo, Postgres};

pub(crate) use sqlx_core::column::{Column, ColumnIndex};
use sqlx_core::column::ColumnOrigin;

#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
Expand All @@ -13,6 +14,8 @@ pub struct PgColumn {
pub(crate) relation_id: Option<crate::types::Oid>,
#[cfg_attr(feature = "offline", serde(skip))]
pub(crate) relation_attribute_no: Option<i16>,
#[cfg_attr(feature = "offline", serde(skip))]
pub(crate) origin: ColumnOrigin,
}

impl PgColumn {
Expand Down Expand Up @@ -51,4 +54,8 @@ impl Column for PgColumn {
fn type_info(&self) -> &PgTypeInfo {
&self.type_info
}

fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}
59 changes: 59 additions & 0 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::btree_map;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
Expand All @@ -14,6 +15,9 @@ use futures_core::future::BoxFuture;
use smallvec::SmallVec;
use sqlx_core::query_builder::QueryBuilder;
use std::sync::Arc;
use sqlx_core::column::{ColumnOrigin, TableColumn};
use sqlx_core::hash_map;
use crate::connection::TableColumns;

/// Describes the type of the `pg_type.typtype` column
///
Expand Down Expand Up @@ -122,13 +126,20 @@ impl PgConnection {
let type_info = self
.maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
.await?;

let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) {
self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await?
} else {
ColumnOrigin::Expression
};

let column = PgColumn {
ordinal: index,
name: name.clone(),
type_info,
relation_id: field.relation_id,
relation_attribute_no: field.relation_attribute_no,
origin,
};

columns.push(column);
Expand Down Expand Up @@ -188,6 +199,54 @@ impl PgConnection {
Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
}
}

async fn maybe_fetch_column_origin(
&mut self,
relation_id: Oid,
attribute_no: i16,
should_fetch: bool,
) -> Result<ColumnOrigin, Error> {
let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) {
hash_map::Entry::Occupied(table_columns) => {
table_columns.into_mut()
},
hash_map::Entry::Vacant(vacant) => {
if !should_fetch { return Ok(ColumnOrigin::Unknown); }

let table_name: String = query_scalar("SELECT $1::oid::regclass::text")
.bind(relation_id)
.fetch_one(&mut *self)
.await?;

vacant.insert(TableColumns {
table_name: table_name.into(),
columns: Default::default(),
})
}
};

let column_name = match table_columns.columns.entry(attribute_no) {
btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()),
btree_map::Entry::Vacant(vacant) => {
if !should_fetch { return Ok(ColumnOrigin::Unknown); }

let column_name: String = query_scalar(
"SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2"
)
.bind(relation_id)
.bind(attribute_no)
.fetch_one(&mut *self)
.await?;

Arc::clone(vacant.insert(column_name.into()))
}
};

Ok(ColumnOrigin::Table(TableColumn {
table: table_columns.table_name.clone(),
name: column_name
}))
}

fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
Expand Down
8 changes: 8 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

Expand Down Expand Up @@ -57,6 +58,7 @@ pub struct PgConnection {
cache_type_info: HashMap<Oid, PgTypeInfo>,
cache_type_oid: HashMap<UStr, Oid>,
cache_elem_type_to_array: HashMap<Oid, Oid>,
cache_table_to_column_names: HashMap<Oid, TableColumns>,

// number of ReadyForQuery messages that we are currently expecting
pub(crate) pending_ready_for_query_count: usize,
Expand All @@ -68,6 +70,12 @@ pub struct PgConnection {
log_settings: LogSettings,
}

pub(crate) struct TableColumns {
table_name: Arc<str>,
/// Attribute number -> name.
columns: BTreeMap<i16, Arc<str>>,
}

impl PgConnection {
/// the version number of the server in `libpq` format
pub fn server_version_num(&self) -> Option<u32> {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub use sqlx_core::acquire::Acquire;
pub use sqlx_core::arguments::{Arguments, IntoArguments};
pub use sqlx_core::column::Column;
pub use sqlx_core::column::ColumnIndex;
pub use sqlx_core::column::ColumnOrigin;
pub use sqlx_core::connection::{ConnectOptions, Connection};
pub use sqlx_core::database::{self, Database};
pub use sqlx_core::describe::Describe;
Expand Down

0 comments on commit 6cc3ac0

Please sign in to comment.