Skip to content

Commit

Permalink
refactor: add origin information to Column
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Sep 18, 2024
1 parent 25c755b commit 3e812b6
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 7 deletions.
54 changes: 54 additions & 0 deletions sqlx-core/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::database::Database;
use crate::error::Error;

use std::fmt::Debug;
use std::sync::Arc;

pub trait Column: 'static + Send + Sync + Debug {
type Database: Database<Column = Self>;
Expand All @@ -20,6 +21,59 @@ pub trait Column: 'static + Send + Sync + Debug {

/// Gets the type information for the column.
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;

/// If this column comes from a table, return the table and original column name.
///
/// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression
/// or else the source table could not be determined.
///
/// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information,
/// or has not overridden this method.
// This method returns an owned value instead of a reference,
// to give the implementor more flexibility.
fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown }
}

/// A [`Column`] that originates from a table.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct TableColumn {
/// The name of the table (optionally schema-qualified) that the column comes from.
pub table: Arc<str>,
/// The original name of the column.
pub name: Arc<str>,
}

/// The possible statuses for our knowledge of the origin of a [`Column`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub enum ColumnOrigin {
/// The column is known to originate from a table.
///
/// Included is the table name and original column name.
Table(TableColumn),
/// The column originates from an expression, or else its origin could not be determined.
Expression,
/// The database driver does not know the column origin at this time.
///
/// This may happen if:
/// * The connection is in the middle of executing a query,
/// and cannot query the catalog to fetch this information.
/// * The connection does not have access to the database catalog.
/// * The implementation of [`Column`] did not override [`Column::origin()`].
#[default]
Unknown,
}

impl ColumnOrigin {
/// Returns the true column origin, if known.
pub fn table_column(&self) -> Option<&TableColumn> {
if let Self::Table(table_column) = self {
Some(table_column)
} else {
None
}
}
}

/// A type that can be used to index into a [`Row`] or [`Statement`].
Expand Down
7 changes: 7 additions & 0 deletions sqlx-mysql/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub struct MySqlColumn {
pub(crate) name: UStr,
pub(crate) type_info: MySqlTypeInfo,

#[cfg_attr(feature = "offline", serde(default))]
pub(crate) origin: ColumnOrigin,

#[cfg_attr(feature = "offline", serde(skip))]
pub(crate) flags: Option<ColumnFlags>,
}
Expand All @@ -28,4 +31,8 @@ impl Column for MySqlColumn {
fn type_info(&self) -> &MySqlTypeInfo {
&self.type_info
}

fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}
21 changes: 21 additions & 0 deletions sqlx-mysql/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use std::{borrow::Cow, sync::Arc};
use sqlx_core::column::{ColumnOrigin, TableColumn};

impl MySqlConnection {
async fn prepare_statement<'c>(
Expand Down Expand Up @@ -382,18 +383,38 @@ async fn recv_result_columns(
fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
// if the alias is empty, use the alias
// only then use the name
let column_name = def.name()?;

let name = match (def.name()?, def.alias()?) {
(_, alias) if !alias.is_empty() => UStr::new(alias),
(name, _) => UStr::new(name),
};

let table = def.table()?;

let origin = if table.is_empty() {
ColumnOrigin::Expression
} else {
let schema = def.schema()?;

ColumnOrigin::Table(TableColumn {
table: if !schema.is_empty() {
format!("{schema}.{table}").into()
} else {
table.into()
},
name: column_name.into(),
})
};

let type_info = MySqlTypeInfo::from_column(def);

Ok(MySqlColumn {
name,
type_info,
ordinal,
flags: Some(def.flags),
origin,
})
}

Expand Down
16 changes: 11 additions & 5 deletions sqlx-mysql/src/protocol/text/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::str::from_utf8;
use std::str;

use bitflags::bitflags;
use bytes::{Buf, Bytes};
Expand Down Expand Up @@ -104,11 +104,9 @@ pub enum ColumnType {
pub(crate) struct ColumnDefinition {
#[allow(unused)]
catalog: Bytes,
#[allow(unused)]
schema: Bytes,
#[allow(unused)]
table_alias: Bytes,
#[allow(unused)]
table: Bytes,
alias: Bytes,
name: Bytes,
Expand All @@ -125,12 +123,20 @@ impl ColumnDefinition {
// NOTE: strings in-protocol are transmitted according to the client character set
// as this is UTF-8, all these strings should be UTF-8

pub(crate) fn schema(&self) -> Result<&str, Error> {
str::from_utf8(&self.schema).map_err(Error::protocol)
}

pub(crate) fn table(&self) -> Result<&str, Error> {
str::from_utf8(&self.table).map_err(Error::protocol)
}

pub(crate) fn name(&self) -> Result<&str, Error> {
from_utf8(&self.name).map_err(Error::protocol)
str::from_utf8(&self.name).map_err(Error::protocol)
}

pub(crate) fn alias(&self) -> Result<&str, Error> {
from_utf8(&self.alias).map_err(Error::protocol)
str::from_utf8(&self.alias).map_err(Error::protocol)
}
}

Expand Down
9 changes: 9 additions & 0 deletions sqlx-postgres/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ use crate::ext::ustr::UStr;
use crate::{PgTypeInfo, Postgres};

pub(crate) use sqlx_core::column::{Column, ColumnIndex};
use sqlx_core::column::ColumnOrigin;

#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct PgColumn {
pub(crate) ordinal: usize,
pub(crate) name: UStr,
pub(crate) type_info: PgTypeInfo,

#[cfg_attr(feature = "offline", serde(default))]
pub(crate) origin: ColumnOrigin,

#[cfg_attr(feature = "offline", serde(skip))]
pub(crate) relation_id: Option<crate::types::Oid>,
#[cfg_attr(feature = "offline", serde(skip))]
Expand Down Expand Up @@ -51,4 +56,8 @@ impl Column for PgColumn {
fn type_info(&self) -> &PgTypeInfo {
&self.type_info
}

fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}
59 changes: 59 additions & 0 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::btree_map;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
Expand All @@ -14,6 +15,9 @@ use futures_core::future::BoxFuture;
use smallvec::SmallVec;
use sqlx_core::query_builder::QueryBuilder;
use std::sync::Arc;
use sqlx_core::column::{ColumnOrigin, TableColumn};
use sqlx_core::hash_map;
use crate::connection::TableColumns;

/// Describes the type of the `pg_type.typtype` column
///
Expand Down Expand Up @@ -122,13 +126,20 @@ impl PgConnection {
let type_info = self
.maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
.await?;

let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) {
self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await?
} else {
ColumnOrigin::Expression
};

let column = PgColumn {
ordinal: index,
name: name.clone(),
type_info,
relation_id: field.relation_id,
relation_attribute_no: field.relation_attribute_no,
origin,
};

columns.push(column);
Expand Down Expand Up @@ -188,6 +199,54 @@ impl PgConnection {
Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
}
}

async fn maybe_fetch_column_origin(
&mut self,
relation_id: Oid,
attribute_no: i16,
should_fetch: bool,
) -> Result<ColumnOrigin, Error> {
let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) {
hash_map::Entry::Occupied(table_columns) => {
table_columns.into_mut()
},
hash_map::Entry::Vacant(vacant) => {
if !should_fetch { return Ok(ColumnOrigin::Unknown); }

let table_name: String = query_scalar("SELECT $1::oid::regclass::text")
.bind(relation_id)
.fetch_one(&mut *self)
.await?;

vacant.insert(TableColumns {
table_name: table_name.into(),
columns: Default::default(),
})
}
};

let column_name = match table_columns.columns.entry(attribute_no) {
btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()),
btree_map::Entry::Vacant(vacant) => {
if !should_fetch { return Ok(ColumnOrigin::Unknown); }

let column_name: String = query_scalar(
"SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2"
)
.bind(relation_id)
.bind(attribute_no)
.fetch_one(&mut *self)
.await?;

Arc::clone(vacant.insert(column_name.into()))
}
};

Ok(ColumnOrigin::Table(TableColumn {
table: table_columns.table_name.clone(),
name: column_name
}))
}

fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
Expand Down
8 changes: 8 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

Expand Down Expand Up @@ -57,6 +58,7 @@ pub struct PgConnection {
cache_type_info: HashMap<Oid, PgTypeInfo>,
cache_type_oid: HashMap<UStr, Oid>,
cache_elem_type_to_array: HashMap<Oid, Oid>,
cache_table_to_column_names: HashMap<Oid, TableColumns>,

// number of ReadyForQuery messages that we are currently expecting
pub(crate) pending_ready_for_query_count: usize,
Expand All @@ -68,6 +70,12 @@ pub struct PgConnection {
log_settings: LogSettings,
}

pub(crate) struct TableColumns {
table_name: Arc<str>,
/// Attribute number -> name.
columns: BTreeMap<i16, Arc<str>>,
}

impl PgConnection {
/// the version number of the server in `libpq` format
pub fn server_version_num(&self) -> Option<u32> {
Expand Down
7 changes: 7 additions & 0 deletions sqlx-sqlite/src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub struct SqliteColumn {
pub(crate) name: UStr,
pub(crate) ordinal: usize,
pub(crate) type_info: SqliteTypeInfo,

#[cfg_attr(feature = "offline", serde(default))]
pub(crate) origin: ColumnOrigin
}

impl Column for SqliteColumn {
Expand All @@ -25,4 +28,8 @@ impl Column for SqliteColumn {
fn type_info(&self) -> &SqliteTypeInfo {
&self.type_info
}

fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}
3 changes: 3 additions & 0 deletions sqlx-sqlite/src/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri

for col in 0..num {
let name = stmt.handle.column_name(col).to_owned();

let origin = stmt.handle.column_origin(col);

let type_info = if let Some(ty) = stmt.handle.column_decltype(col) {
ty
Expand Down Expand Up @@ -82,6 +84,7 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
name: name.into(),
type_info,
ordinal: col,
origin,
});
}
}
Expand Down
Loading

0 comments on commit 3e812b6

Please sign in to comment.