sql: clean up execution module
erikgrinaker committed Jul 21, 2024
1 parent 5747a71 commit ea1889a
Showing 7 changed files with 98 additions and 104 deletions.
62 changes: 23 additions & 39 deletions src/sql/execution/aggregate.rs
@@ -6,13 +6,14 @@ use itertools::Itertools as _;
use std::collections::BTreeMap;

/// Aggregates row values from the source according to the aggregates, using the
/// group_by expressions as buckets.
pub(super) fn aggregate(
/// group_by expressions as buckets. Emits rows with group_by buckets then
/// aggregates in the given order.
pub fn aggregate(
mut source: Rows,
group_by: Vec<Expression>,
aggregates: Vec<Aggregate>,
) -> Result<Rows> {
let mut aggregator = Aggregator::new(aggregates, group_by);
let mut aggregator = Aggregator::new(group_by, aggregates);
while let Some(row) = source.next().transpose()? {
aggregator.add(row)?;
}
@@ -23,39 +24,38 @@ pub(super) fn aggregate(
struct Aggregator {
/// Bucketed accumulators (by group_by values).
buckets: BTreeMap<Vec<Value>, Vec<Accumulator>>,
/// The set of empty accumulators.
/// The set of empty accumulators. Used to create new buckets.
empty: Vec<Accumulator>,
/// Expressions to accumulate. Indexes map to accumulators.
exprs: Vec<Expression>,
/// Group by expressions. Indexes map to bucket values.
group_by: Vec<Expression>,
/// Expressions to accumulate. Indexes map to accumulators.
expressions: Vec<Expression>,
}

impl Aggregator {
/// Creates a new aggregator for the given aggregates and GROUP BY buckets.
fn new(aggregates: Vec<Aggregate>, group_by: Vec<Expression>) -> Self {
/// Creates a new aggregator for the given GROUP BY buckets and aggregates.
fn new(group_by: Vec<Expression>, aggregates: Vec<Aggregate>) -> Self {
use Aggregate::*;
let accumulators = aggregates.iter().map(Accumulator::new).collect();
let exprs = aggregates
let expressions = aggregates
.into_iter()
.map(|aggregate| match aggregate {
Average(expr) | Count(expr) | Max(expr) | Min(expr) | Sum(expr) => expr,
})
.collect();
Self { buckets: BTreeMap::new(), empty: accumulators, group_by, exprs }
Self { buckets: BTreeMap::new(), empty: accumulators, group_by, expressions }
}

/// Adds a row to the aggregator.
fn add(&mut self, row: Row) -> Result<()> {
// Compute the bucket value.
let bucket: Vec<Value> =
self.group_by.iter().map(|expr| expr.evaluate(Some(&row))).collect::<Result<_>>()?;
self.group_by.iter().map(|expr| expr.evaluate(Some(&row))).try_collect()?;

// Compute and accumulate the input values.
let accumulators = self.buckets.entry(bucket).or_insert_with(|| self.empty.clone());
for (accumulator, expr) in accumulators.iter_mut().zip(&self.exprs) {
let value = expr.evaluate(Some(&row))?;
accumulator.add(value)?;
for (accumulator, expr) in accumulators.iter_mut().zip(&self.expressions) {
accumulator.add(expr.evaluate(Some(&row))?)?;
}
Ok(())
}
@@ -108,35 +108,19 @@ impl Accumulator {

/// Adds a value to the accumulator.
fn add(&mut self, value: Value) -> Result<()> {
use std::cmp::Ordering;

// NULL values are ignored in aggregates.
// Aggregates ignore NULL values.
if value == Value::Null {
return Ok(());
}

match self {
Self::Average { sum, count } => {
*sum = sum.checked_add(&value)?;
*count += 1;
}

Self::Count(c) => *c += 1,

Self::Average { sum, count } => (*sum, *count) = (sum.checked_add(&value)?, *count + 1),
Self::Count(count) => *count += 1,
Self::Max(max @ None) => *max = Some(value),
Self::Max(Some(max)) => {
if value.cmp(max) == Ordering::Greater {
*max = value
}
}

Self::Max(Some(max)) if value > *max => *max = value,
Self::Max(Some(_)) => {}
Self::Min(min @ None) => *min = Some(value),
Self::Min(Some(min)) => {
if value.cmp(min) == Ordering::Less {
*min = value
}
}

Self::Min(Some(min)) if value < *min => *min = value,
Self::Min(Some(_)) => {}
Self::Sum(sum @ None) => *sum = Some(Value::Integer(0).checked_add(&value)?),
Self::Sum(Some(sum)) => *sum = sum.checked_add(&value)?,
}
@@ -148,9 +132,9 @@ impl Accumulator {
Ok(match self {
Self::Average { count: 0, sum: _ } => Value::Null,
Self::Average { count, sum } => sum.checked_div(&Value::Integer(count))?,
Self::Count(c) => c.into(),
Self::Count(count) => count.into(),
Self::Max(Some(value)) | Self::Min(Some(value)) | Self::Sum(Some(value)) => value,
Self::Max(None) | Self::Min(None) | Self::Sum(None) => Value::Null,
Self::Max(Some(v)) | Self::Min(Some(v)) | Self::Sum(Some(v)) => v,
})
}
}
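
For context, a minimal sketch of the bucketed-accumulator pattern that Aggregator implements: bucket rows by their group_by values, accumulate per bucket, then emit one row per bucket. The types below (String group keys, i64 values, a single running SUM) are simplified stand-ins for toydb's Value, Expression, and Accumulator types, not the actual API.

use std::collections::BTreeMap;

// Groups rows by a key and keeps a running SUM per bucket, mirroring how
// Aggregator maps group_by bucket values to a Vec of accumulators.
fn sum_by_group(rows: &[(String, i64)]) -> Vec<(String, i64)> {
    let mut buckets: BTreeMap<String, i64> = BTreeMap::new();
    for (group, value) in rows {
        // New buckets start from the empty accumulator state (here, 0).
        *buckets.entry(group.clone()).or_insert(0) += value;
    }
    // BTreeMap iteration yields buckets in sorted group order.
    buckets.into_iter().collect()
}

fn main() {
    let rows = vec![
        ("drama".to_string(), 3),
        ("sci-fi".to_string(), 5),
        ("drama".to_string(), 4),
    ];
    assert_eq!(sum_by_group(&rows), vec![("drama".to_string(), 7), ("sci-fi".to_string(), 5)]);
}
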
29 changes: 24 additions & 5 deletions src/sql/execution/execute.rs
@@ -1,8 +1,4 @@
use super::aggregate;
use super::join;
use super::source;
use super::transform;
use super::write;
use super::{aggregate, join, source, transform, write};
use crate::error::Result;
use crate::sql::engine::{Catalog, Transaction};
use crate::sql::planner::{Node, Plan};
@@ -56,6 +52,29 @@ pub fn execute_plan(
}

/// Recursively executes a query plan node, returning a row iterator.
///
/// Rows stream through the plan node tree from the branches to the root. Nodes
/// recursively pull input rows upwards from their child node(s), process them,
/// and hand the resulting rows off to their parent node.
///
/// Below is an example of an (unoptimized) query plan:
///
/// SELECT title, released, genres.name AS genre
/// FROM movies INNER JOIN genres ON movies.genre_id = genres.id
/// WHERE released >= 2000 ORDER BY released DESC
///
/// Order: movies.released desc
/// └─ Projection: movies.title, movies.released, genres.name as genre
/// └─ Filter: movies.released >= 2000
/// └─ NestedLoopJoin: inner on movies.genre_id = genres.id
/// ├─ Scan: movies
/// └─ Scan: genres
///
/// Rows flow from the tree leaves to the root. The Scan nodes read and emit
/// table rows from storage. They are passed to the NestedLoopJoin node which
/// joins the rows from the two tables, then the Filter node discards old
/// movies, the Projection node picks out the requested columns, and the Order
/// node sorts them before emitting the rows to the client.
pub fn execute(node: Node, txn: &impl Transaction) -> Result<Rows> {
Ok(match node {
Node::Aggregate { source, group_by, aggregates } => {
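
A minimal sketch of this pull-based execution model, with simplified stand-in Node variants and row types rather than toydb's actual planner and Rows types (the variants and fields below are illustrative assumptions):

// Simplified stand-ins for toydb's Row/Rows types.
type Row = Vec<i64>;
type Rows = Box<dyn Iterator<Item = Row>>;

// A tiny plan node tree: Scan at the leaves, Filter and Limit above it.
enum Node {
    Scan { rows: Vec<Row> },
    Filter { source: Box<Node>, column: usize, min: i64 },
    Limit { source: Box<Node>, limit: usize },
}

// Recursively builds an iterator for each node, pulling rows from its child.
fn execute(node: Node) -> Rows {
    match node {
        Node::Scan { rows } => Box::new(rows.into_iter()),
        Node::Filter { source, column, min } => {
            Box::new(execute(*source).filter(move |row| row[column] >= min))
        }
        Node::Limit { source, limit } => Box::new(execute(*source).take(limit)),
    }
}

fn main() {
    // Limit: 2
    // └─ Filter: column 0 >= 2000
    //    └─ Scan
    let plan = Node::Limit {
        limit: 2,
        source: Box::new(Node::Filter {
            column: 0,
            min: 2000,
            source: Box::new(Node::Scan {
                rows: vec![vec![1999], vec![2001], vec![2010], vec![2024]],
            }),
        }),
    };
    assert_eq!(execute(plan).collect::<Vec<Row>>(), vec![vec![2001], vec![2010]]);
}
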
15 changes: 8 additions & 7 deletions src/sql/execution/join.rs
@@ -1,4 +1,4 @@
use crate::errdata;
use crate::errinput;
use crate::error::Result;
use crate::sql::types::{Expression, Row, Rows, Value};

@@ -11,7 +11,7 @@ use std::iter::Peekable;
/// there are no matches in the right source for a row in the left source, a
/// joined row with NULL values for the right source is returned (typically used
/// for a LEFT JOIN).
pub(super) fn nested_loop(
pub fn nested_loop(
left: Rows,
right: Rows,
right_size: usize,
@@ -57,7 +57,7 @@ impl NestedLoopIterator {
Ok(Self { left, right, right_init, right_size, right_match: false, predicate, outer })
}

// Returns the next joined row, if any, with error handling.
// Returns the next joined row, if any.
fn try_next(&mut self) -> Result<Option<Row>> {
// While there is a valid left row, look for a right-hand match to return.
while let Some(Ok(left_row)) = self.left.peek() {
@@ -70,7 +70,7 @@ impl NestedLoopIterator {
Some(predicate) => match predicate.evaluate(Some(&row))? {
Value::Boolean(true) => true,
Value::Boolean(false) | Value::Null => false,
v => return errdata!("join predicate returned {v}, expected boolean"),
v => return errinput!("join predicate returned {v}, expected boolean"),
},
None => true,
};
@@ -118,7 +118,7 @@ impl Iterator for NestedLoopIterator {
/// matching rows in the hash table. If outer is true, and there is no match
/// in the right source for a row in the left source, a row with NULL values
/// for the right source is emitted instead.
pub(super) fn hash(
pub fn hash(
left: Rows,
left_column: usize,
right: Rows,
@@ -141,7 +141,7 @@ pub(super) fn hash(
let empty = std::iter::repeat(Value::Null).take(right_size);

// Set up the join iterator.
Ok(Box::new(left.flat_map(move |result| -> Rows {
let join = left.flat_map(move |result| -> Rows {
// Pass through errors.
let Ok(row) = result else {
return Box::new(std::iter::once(result));
@@ -159,5 +159,6 @@
}
None => Box::new(std::iter::empty()),
}
})))
});
Ok(Box::new(join))
}
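
A minimal sketch of the hash join strategy that hash() implements: build a hash table over the right-hand rows keyed on the join column, then probe it once per left-hand row, padding with NULLs for unmatched outer-join rows. The concrete types below (i64 columns, Option<i64> as NULL, an eager Vec result instead of a lazy row iterator) are simplifying assumptions for illustration:

use std::collections::HashMap;

fn hash_join(
    left: Vec<Vec<i64>>,
    left_column: usize,
    right: Vec<Vec<i64>>,
    right_column: usize,
    right_size: usize,
    outer: bool,
) -> Vec<Vec<Option<i64>>> {
    // Build phase: group the right rows by their join key.
    let mut table: HashMap<i64, Vec<Vec<i64>>> = HashMap::new();
    for row in right {
        table.entry(row[right_column]).or_default().push(row);
    }

    // Probe phase: emit joined rows, padding with NULLs if outer and unmatched.
    let mut out = Vec::new();
    for row in left {
        let left_values: Vec<Option<i64>> = row.iter().copied().map(Some).collect();
        match table.get(&row[left_column]) {
            Some(matches) => {
                for matched in matches {
                    let mut joined = left_values.clone();
                    joined.extend(matched.iter().copied().map(Some));
                    out.push(joined);
                }
            }
            None if outer => {
                let mut joined = left_values.clone();
                joined.extend(std::iter::repeat(None).take(right_size));
                out.push(joined);
            }
            None => {}
        }
    }
    out
}

fn main() {
    let movies = vec![vec![1, 10], vec![2, 11], vec![3, 99]]; // (id, genre_id)
    let genres = vec![vec![10], vec![11]]; // (id)
    let joined = hash_join(movies, 1, genres, 0, 1, true);
    assert_eq!(joined, vec![
        vec![Some(1), Some(10), Some(10)],
        vec![Some(2), Some(11), Some(11)],
        vec![Some(3), Some(99), None],
    ]);
}

The real implementation streams rows lazily via flat_map and passes source errors through, but the build-then-probe structure is the same.
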
2 changes: 2 additions & 0 deletions src/sql/execution/mod.rs
@@ -1,3 +1,5 @@
//! Executes a `Plan` against a `sql::engine::Engine`.

mod aggregate;
mod execute;
mod join;
14 changes: 5 additions & 9 deletions src/sql/execution/source.rs
@@ -3,21 +3,17 @@ use crate::sql::engine::Transaction;
use crate::sql::types::{Expression, Rows, Table, Value};

/// A table scan source.
pub(super) fn scan(
txn: &impl Transaction,
table: Table,
filter: Option<Expression>,
) -> Result<Rows> {
pub fn scan(txn: &impl Transaction, table: Table, filter: Option<Expression>) -> Result<Rows> {
Ok(Box::new(txn.scan(&table.name, filter)?))
}

/// A primary key lookup source.
pub(super) fn lookup_key(txn: &impl Transaction, table: String, keys: Vec<Value>) -> Result<Rows> {
pub fn lookup_key(txn: &impl Transaction, table: String, keys: Vec<Value>) -> Result<Rows> {
Ok(Box::new(txn.get(&table, &keys)?.into_iter().map(Ok)))
}

/// An index lookup source.
pub(super) fn lookup_index(
pub fn lookup_index(
txn: &impl Transaction,
table: String,
column: String,
@@ -28,11 +24,11 @@ pub(super) fn lookup_index(
}

/// Returns nothing. Used to short-circuit nodes that can't produce any rows.
pub(super) fn nothing() -> Rows {
pub fn nothing() -> Rows {
Box::new(std::iter::empty())
}

/// Emits predefined constant values.
pub(super) fn values(rows: Vec<Vec<Expression>>) -> Rows {
pub fn values(rows: Vec<Vec<Expression>>) -> Rows {
Box::new(rows.into_iter().map(|row| row.into_iter().map(|expr| expr.evaluate(None)).collect()))
}
46 changes: 20 additions & 26 deletions src/sql/execution/transform.rs
@@ -1,46 +1,42 @@
use itertools::izip;

use crate::errinput;
use crate::error::Result;
use crate::sql::planner::Direction;
use crate::sql::types::{Expression, Rows, Value};

use itertools::{izip, Itertools as _};

/// Filters the input rows (i.e. WHERE).
pub(super) fn filter(source: Rows, predicate: Expression) -> Rows {
pub fn filter(source: Rows, predicate: Expression) -> Rows {
Box::new(source.filter_map(move |r| {
r.and_then(|row| match predicate.evaluate(Some(&row))? {
Value::Boolean(true) => Ok(Some(row)),
Value::Boolean(false) => Ok(None),
Value::Null => Ok(None),
Value::Boolean(false) | Value::Null => Ok(None),
value => errinput!("filter returned {value}, expected boolean",),
})
.transpose()
}))
}

/// Limits the result to the given number of rows (i.e. LIMIT).
pub(super) fn limit(source: Rows, limit: usize) -> Rows {
pub fn limit(source: Rows, limit: usize) -> Rows {
Box::new(source.take(limit))
}

/// Skips the given number of rows (i.e. OFFSET).
pub(super) fn offset(source: Rows, offset: usize) -> Rows {
pub fn offset(source: Rows, offset: usize) -> Rows {
Box::new(source.skip(offset))
}

/// Sorts the rows (i.e. ORDER BY).
pub(super) fn order(source: Rows, order: Vec<(Expression, Direction)>) -> Result<Rows> {
pub fn order(source: Rows, order: Vec<(Expression, Direction)>) -> Result<Rows> {
// We can't use sort_by_cached_key(), since expression evaluation is
// fallible, and since we may have to vary the sort direction of each
// expression. Precompute the sort values instead, and map them based on
// the row index.
let mut irows: Vec<_> =
source.enumerate().map(|(i, r)| r.map(|row| (i, row))).collect::<Result<_>>()?;

let mut irows: Vec<_> = source.enumerate().map(|(i, r)| r.map(|row| (i, row))).try_collect()?;
let mut sort_values = Vec::with_capacity(irows.len());
for (_, row) in &irows {
let values: Vec<_> =
order.iter().map(|(e, _)| e.evaluate(Some(row))).collect::<Result<_>>()?;
let values: Vec<_> = order.iter().map(|(e, _)| e.evaluate(Some(row))).try_collect()?;
sort_values.push(values)
}

@@ -60,24 +56,22 @@ pub(super) fn order(source: Rows, order: Vec<(Expression, Direction)>) -> Result
}

/// Projects the rows using the given expressions (i.e. SELECT).
pub(super) fn project(source: Rows, expressions: Vec<Expression>) -> Rows {
Box::new(source.map(move |r| {
r.and_then(|row| expressions.iter().map(|e| e.evaluate(Some(&row))).collect())
pub fn project(source: Rows, expressions: Vec<Expression>) -> Rows {
Box::new(source.map(move |result| {
result.and_then(|row| expressions.iter().map(|e| e.evaluate(Some(&row))).collect())
}))
}

/// Remaps source columns to target column indexes, or drops them if None.
pub(super) fn remap(source: Rows, targets: Vec<Option<usize>>) -> Rows {
pub fn remap(source: Rows, targets: Vec<Option<usize>>) -> Rows {
let size = targets.iter().filter_map(|v| *v).map(|i| i + 1).max().unwrap_or(0);
Box::new(source.map(move |r| {
r.map(|row| {
let mut out = vec![Value::Null; size];
for (value, target) in row.into_iter().zip(&targets) {
if let Some(index) = target {
out[*index] = value;
}
Box::new(source.map_ok(move |row| {
let mut out = vec![Value::Null; size];
for (value, target) in row.into_iter().zip(&targets) {
if let Some(index) = target {
out[*index] = value;
}
out
})
}
out
}))
}
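
A minimal sketch of the precomputed-sort-key approach used by order() above: evaluate the sort values once per row, then sort row indexes by those keys while honoring each column's direction. The Direction enum and plain i64 keys below are simplified stand-ins for toydb's planner types:

use std::cmp::Ordering;

#[derive(Clone, Copy)]
enum Direction {
    Ascending,
    Descending,
}

fn order_rows(rows: Vec<Vec<i64>>, order: &[(usize, Direction)]) -> Vec<Vec<i64>> {
    // Precompute the sort key for each row (one value per ORDER BY column).
    let keys: Vec<Vec<i64>> =
        rows.iter().map(|row| order.iter().map(|(col, _)| row[*col]).collect()).collect();

    // Sort row indexes by comparing keys column by column, flipping the
    // comparison for descending columns.
    let mut indexes: Vec<usize> = (0..rows.len()).collect();
    indexes.sort_by(|&a, &b| {
        for (i, (_, direction)) in order.iter().enumerate() {
            let ordering = match direction {
                Direction::Ascending => keys[a][i].cmp(&keys[b][i]),
                Direction::Descending => keys[b][i].cmp(&keys[a][i]),
            };
            if ordering != Ordering::Equal {
                return ordering;
            }
        }
        Ordering::Equal
    });
    indexes.into_iter().map(|i| rows[i].clone()).collect()
}

fn main() {
    let rows = vec![vec![2010, 1], vec![2001, 2], vec![2010, 3]];
    let sorted = order_rows(rows, &[(0, Direction::Descending), (1, Direction::Ascending)]);
    assert_eq!(sorted, vec![vec![2010, 1], vec![2010, 3], vec![2001, 2]]);
}

Precomputing the keys avoids re-evaluating fallible expressions during every comparison, which is why the real code can't simply use sort_by_cached_key().
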