Skip to content

Commit

Permalink
sql: rewrite aggregate function and GROUP BY planning
Browse files Browse the repository at this point in the history
  • Loading branch information
erikgrinaker committed Jul 14, 2024
1 parent f4fac4f commit c696b1e
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 386 deletions.
10 changes: 8 additions & 2 deletions src/sql/execution/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@ struct Aggregator {
impl Aggregator {
/// Creates a new aggregator for the given aggregates and GROUP BY buckets.
fn new(aggregates: Vec<Aggregate>, group_by: Vec<Expression>) -> Self {
use Aggregate::*;
let accumulators = aggregates.iter().map(Accumulator::new).collect();
let exprs = aggregates.into_iter().map(|a| a.into_inner()).collect();
let exprs = aggregates
.into_iter()
.map(|aggregate| match aggregate {
Average(expr) | Count(expr) | Max(expr) | Min(expr) | Sum(expr) => expr,
})
.collect();
Self { buckets: BTreeMap::new(), empty: accumulators, group_by, exprs }
}

Expand Down Expand Up @@ -101,7 +107,7 @@ impl Accumulator {
}

/// Adds a value to the accumulator.
/// TODO: have this take &Value.
/// TODO: NULL values should possibly be ignored, not yield NULL (see Postgres?).
fn add(&mut self, value: Value) -> Result<()> {
use std::cmp::Ordering;
match (self, value) {
Expand Down
84 changes: 2 additions & 82 deletions src/sql/parser/ast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::error::Result;
use crate::sql::types::DataType;

use std::collections::BTreeMap;
Expand Down Expand Up @@ -101,9 +100,6 @@ pub enum Order {
pub enum Expression {
/// A field reference, with an optional table qualifier.
Field(Option<String>, String),
/// A column index (only used during planning to break off subtrees).
/// TODO: get rid of this, planning shouldn't modify the AST.
Column(usize),
/// A literal value.
Literal(Literal),
/// A function call (name and parameters).
Expand Down Expand Up @@ -188,82 +184,6 @@ pub enum Operator {
}

impl Expression {
/// Transforms the expression tree depth-first by applying a closure before
/// and after descending.
///
/// TODO: make closures non-mut.
pub fn transform<B, A>(mut self, before: &mut B, after: &mut A) -> Result<Self>
where
B: FnMut(Self) -> Result<Self>,
A: FnMut(Self) -> Result<Self>,
{
use Operator::*;
self = before(self)?;

// Helper for transforming a boxed expression.
let mut transform = |mut expr: Box<Expression>| -> Result<Box<Expression>> {
*expr = expr.transform(before, after)?;
Ok(expr)
};

self = match self {
Self::Literal(_) | Self::Field(_, _) | Self::Column(_) => self,

Self::Function(name, exprs) => Self::Function(
name,
exprs.into_iter().map(|e| e.transform(before, after)).collect::<Result<_>>()?,
),

Self::Operator(op) => Self::Operator(match op {
Add(lhs, rhs) => Add(transform(lhs)?, transform(rhs)?),
And(lhs, rhs) => And(transform(lhs)?, transform(rhs)?),
Divide(lhs, rhs) => Divide(transform(lhs)?, transform(rhs)?),
Equal(lhs, rhs) => Equal(transform(lhs)?, transform(rhs)?),
Exponentiate(lhs, rhs) => Exponentiate(transform(lhs)?, transform(rhs)?),
Factorial(expr) => Factorial(transform(expr)?),
GreaterThan(lhs, rhs) => GreaterThan(transform(lhs)?, transform(rhs)?),
GreaterThanOrEqual(lhs, rhs) => {
GreaterThanOrEqual(transform(lhs)?, transform(rhs)?)
}
Identity(expr) => Identity(transform(expr)?),
IsNaN(expr) => IsNaN(transform(expr)?),
IsNull(expr) => IsNull(transform(expr)?),
LessThan(lhs, rhs) => LessThan(transform(lhs)?, transform(rhs)?),
LessThanOrEqual(lhs, rhs) => LessThanOrEqual(transform(lhs)?, transform(rhs)?),
Like(lhs, rhs) => Like(transform(lhs)?, transform(rhs)?),
Modulo(lhs, rhs) => Modulo(transform(lhs)?, transform(rhs)?),
Multiply(lhs, rhs) => Multiply(transform(lhs)?, transform(rhs)?),
Negate(expr) => Negate(transform(expr)?),
Not(expr) => Not(transform(expr)?),
NotEqual(lhs, rhs) => NotEqual(transform(lhs)?, transform(rhs)?),
Or(lhs, rhs) => Or(transform(lhs)?, transform(rhs)?),
Subtract(lhs, rhs) => Subtract(transform(lhs)?, transform(rhs)?),
}),
};
self = after(self)?;
Ok(self)
}

/// Transforms an expression using a mutable reference.
/// TODO: try to get rid of this and replace_with().
pub fn transform_mut<B, A>(&mut self, before: &mut B, after: &mut A) -> Result<()>
where
B: FnMut(Self) -> Result<Self>,
A: FnMut(Self) -> Result<Self>,
{
self.replace_with(|e| e.transform(before, after))
}

/// Replaces the expression with result of the closure. Helper function for
/// transform().
fn replace_with(&mut self, mut f: impl FnMut(Self) -> Result<Self>) -> Result<()> {
// Temporarily replace expression with a null value, in case closure panics. May consider
// replace_with crate if this hampers performance.
let expr = std::mem::replace(self, Expression::Literal(Literal::Null));
*self = f(expr)?;
Ok(())
}

/// Walks the expression tree depth-first, calling a closure for every node.
/// Halts and returns false if the closure returns false.
pub fn walk(&self, visitor: &mut impl FnMut(&Expression) -> bool) -> bool {
Expand Down Expand Up @@ -297,7 +217,7 @@ impl Expression {

Self::Function(_, exprs) => exprs.iter().any(|expr| expr.walk(visitor)),

Self::Literal(_) | Self::Field(_, _) | Self::Column(_) => true,
Self::Literal(_) | Self::Field(_, _) => true,
}
}

Expand Down Expand Up @@ -348,7 +268,7 @@ impl Expression {

Self::Function(_, exprs) => exprs.iter().for_each(|expr| expr.collect(visitor, c)),

Self::Literal(_) | Self::Field(_, _) | Self::Column(_) => {}
Self::Literal(_) | Self::Field(_, _) => {}
}
}
}
Expand Down
23 changes: 3 additions & 20 deletions src/sql/planner/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ impl Plan {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Node {
/// Computes aggregate values for the given expressions and group_by buckets
/// across all rows in the source node.
/// across all rows in the source node. The aggregate columns are output
/// first, followed by the group_by columns, in the given order.
/// TODO: reverse the order.
Aggregate { source: Box<Node>, aggregates: Vec<Aggregate>, group_by: Vec<Expression> },
/// Filters source rows, by only emitting rows for which the predicate
/// evaluates to true.
Expand Down Expand Up @@ -465,25 +467,6 @@ impl std::fmt::Display for Aggregate {
}
}

impl Aggregate {
/// Returns the inner aggregate expression. Currently, all aggregate
/// functions take a single input expression.
pub fn into_inner(self) -> Expression {
match self {
Self::Average(expr)
| Self::Count(expr)
| Self::Max(expr)
| Self::Min(expr)
| Self::Sum(expr) => expr,
}
}

// TODO: get rid of this.
pub(super) fn is(name: &str) -> bool {
["avg", "count", "max", "min", "sum"].contains(&name)
}
}

/// A sort order direction.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Direction {
Expand Down
Loading

0 comments on commit c696b1e

Please sign in to comment.