Skip to content

Commit

Permalink
Query: Entity splitting support for regular entities (#28425)
Browse files Browse the repository at this point in the history
Part of #620
  • Loading branch information
smitpatel authored Jul 13, 2022
1 parent d4f0621 commit 43e0755
Show file tree
Hide file tree
Showing 22 changed files with 427 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ protected override LambdaExpression GenerateMaterializationCondition(IEntityType
return baseCondition;
}

var table = entityType.GetViewOrTableMappings().SingleOrDefault()?.Table
var table = entityType.GetViewOrTableMappings().SingleOrDefault(e => e.IsSplitEntityTypePrincipal ?? true)?.Table
?? entityType.GetDefaultMappings().Single().Table;
if (table.IsOptional(entityType))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
return propertyAccess;
}

var table = entityType.GetViewOrTableMappings().SingleOrDefault()?.Table
var table = entityType.GetViewOrTableMappings().SingleOrDefault(e => e.IsSplitEntityTypePrincipal ?? true)?.Table
?? entityType.GetDefaultMappings().Single().Table;
if (!table.IsOptional(entityType))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,26 +320,26 @@ private sealed class ColumnExpressionFindingExpressionVisitor : ExpressionVisito
// Always skip the table of ColumnExpression since it will traverse into deeper subquery
return columnExpression;

case LeftJoinExpression leftJoinExpression:
var leftJoinTableAlias = leftJoinExpression.Table.Alias!;
case PredicateJoinExpressionBase predicateJoinExpressionBase:
var predicateJoinTableAlias = predicateJoinExpressionBase.Table.Alias!;
// Visiting the join predicate will add some columns for join table.
// But if all the referenced columns are in join predicate only then we can remove the join table.
// So if there are no referenced columns yet means there is still potential to remove this table,
// In such case we moved the columns encountered in join predicate to other dictionary and later merge
// if there are more references to the join table outside of join predicate.
// We currently do this only for LeftJoin since that is the only predicate join table we remove.
// We should also remove references to the outer if this column gets removed then that subquery can also remove projections
// But currently we only remove table for TPT scenario in which there are all table expressions which connects via joins.
var joinOnSameLevel = _columnReferenced!.ContainsKey(leftJoinTableAlias);
var noReferences = !joinOnSameLevel || _columnReferenced[leftJoinTableAlias] == null;
base.Visit(leftJoinExpression);
// But currently we only remove table for TPT & entity splitting scenario
// in which there are all table expressions which connects via joins.
var joinOnSameLevel = _columnReferenced!.ContainsKey(predicateJoinTableAlias);
var noReferences = !joinOnSameLevel || _columnReferenced[predicateJoinTableAlias] == null;
base.Visit(predicateJoinExpressionBase);
if (noReferences && joinOnSameLevel)
{
_columnsUsedInJoinCondition![leftJoinTableAlias] = _columnReferenced[leftJoinTableAlias];
_columnReferenced[leftJoinTableAlias] = null;
_columnsUsedInJoinCondition![predicateJoinTableAlias] = _columnReferenced[predicateJoinTableAlias];
_columnReferenced[predicateJoinTableAlias] = null;
}

return leftJoinExpression;
return predicateJoinExpressionBase;

default:
return base.Visit(expression);
Expand Down Expand Up @@ -930,7 +930,7 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
};
newSelectExpression._mutable = selectExpression._mutable;

newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables);
newSelectExpression._removableJoinTables.AddRange(selectExpression._removableJoinTables);

foreach (var kvp in selectExpression._tpcDiscriminatorValues)
{
Expand Down
139 changes: 104 additions & 35 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public sealed partial class SelectExpression : TableExpressionBase

private readonly List<(ColumnExpression Column, ValueComparer Comparer)> _identifier = new();
private readonly List<(ColumnExpression Column, ValueComparer Comparer)> _childIdentifiers = new();
private readonly List<int> _tptLeftJoinTables = new();
private readonly List<int> _removableJoinTables = new();
private readonly Dictionary<TpcTablesExpression, (ColumnExpression, List<string>)> _tpcDiscriminatorValues
= new(ReferenceEqualityComparer.Instance);

Expand Down Expand Up @@ -165,7 +165,7 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
.Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r));

var joinExpression = new LeftJoinExpression(tableExpression, joinPredicate);
_tptLeftJoinTables.Add(_tables.Count);
_removableJoinTables.Add(_tables.Count);
AddTable(joinExpression, tableReferenceExpression);
}

Expand All @@ -187,7 +187,7 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
if (entityTypes.Length == 1)
{
// For single entity case, we don't need discriminator.
var table = GetTableBase(entityTypes[0]);
var table = entityTypes[0].GetViewOrTableMappings().Single().Table;
var tableExpression = new TableExpression(table);

var tableReferenceExpression = new TableReferenceExpression(this, tableExpression.Alias!);
Expand All @@ -212,7 +212,7 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
}
else
{
var tables = entityTypes.Select(e => GetTableBase(e)).ToArray();
var tables = entityTypes.Select(e => e.GetViewOrTableMappings().Single().Table).ToArray();
var properties = GetAllPropertiesInHierarchy(entityType).ToArray();
var propertyNamesMap = new Dictionary<IProperty, string>();
for (var i = 0; i < entityTypes.Length; i++)
Expand Down Expand Up @@ -314,47 +314,106 @@ internal SelectExpression(IEntityType entityType, ISqlExpressionFactory sqlExpre
default:
{
// Also covers TPH
ITableBase table;
TableExpressionBase tableExpression;
if (entityType.GetFunctionMappings().SingleOrDefault(e => e.IsDefaultFunctionMapping) is IFunctionMapping functionMapping)
{
var storeFunction = functionMapping.Table;

table = storeFunction;
tableExpression = new TableValuedFunctionExpression((IStoreFunction)storeFunction, Array.Empty<SqlExpression>());
GenerateNonHierarchyNonSplittingEntityType(
storeFunction, new TableValuedFunctionExpression((IStoreFunction)storeFunction, Array.Empty<SqlExpression>()));
}
else
{
table = GetTableBase(entityType);
tableExpression = new TableExpression(table);
}
var mappings = entityType.GetViewOrTableMappings().ToList();
if (mappings.Count == 1)
{
var table = mappings[0].Table;

var tableReferenceExpression = new TableReferenceExpression(this, tableExpression.Alias!);
AddTable(tableExpression, tableReferenceExpression);
GenerateNonHierarchyNonSplittingEntityType(table, new TableExpression(table));
}
else
{
// entity splitting
var keyProperties = entityType.FindPrimaryKey()!.Properties;
List<ColumnExpression> joinColumns = default!;
var columns = new Dictionary<IProperty, ColumnExpression>();
var tableReferenceExpressionMap = new Dictionary<ITableBase, TableReferenceExpression>();
foreach (var mapping in mappings)
{
var table = mapping.Table;
var tableExpression = new TableExpression(table);
var tableReferenceExpression = new TableReferenceExpression(this, tableExpression.Alias);
tableReferenceExpressionMap[table] = tableReferenceExpression;

var propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
foreach (var property in GetAllPropertiesInHierarchy(entityType))
{
propertyExpressions[property] = CreateColumnExpression(property, table, tableReferenceExpression, nullable: false);
}
if (_tables.Count == 0)
{
AddTable(tableExpression, tableReferenceExpression);
joinColumns = new List<ColumnExpression>();
foreach (var property in keyProperties)
{
var columnExpression = CreateColumnExpression(property, table, tableReferenceExpression, nullable: false);
columns[property] = columnExpression;
joinColumns.Add(columnExpression);
_identifier.Add((columnExpression, property.GetKeyValueComparer()));
}
}
else
{
var innerColumns = keyProperties.Select(
p => CreateColumnExpression(p, table, tableReferenceExpression, nullable: false));

var entityProjection = new EntityProjectionExpression(entityType, propertyExpressions);
_projectionMapping[new ProjectionMember()] = entityProjection;
var joinPredicate = joinColumns.Zip(innerColumns, (l, r) => sqlExpressionFactory.Equal(l, r))
.Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r));

var primaryKey = entityType.FindPrimaryKey();
if (primaryKey != null)
{
foreach (var property in primaryKey.Properties)
{
_identifier.Add((propertyExpressions[property], property.GetKeyValueComparer()));
var joinExpression = new InnerJoinExpression(tableExpression, joinPredicate);
_removableJoinTables.Add(_tables.Count);
AddTable(joinExpression, tableReferenceExpression);
}
}

foreach (var property in entityType.GetProperties())
{
if (property.IsPrimaryKey())
{
continue;
}

var columnBase = mappings.Select(e => e.Table.FindColumn(property)).First(e => e != null)!;
columns[property] = CreateColumnExpression(
property, columnBase, tableReferenceExpressionMap[columnBase.Table], nullable: false);
}

var entityProjection = new EntityProjectionExpression(entityType, columns);
_projectionMapping[new ProjectionMember()] = entityProjection;
}
}
}

break;
}

static ITableBase GetTableBase(IEntityType entityType) => entityType.GetViewOrTableMappings().Single().Table;
void GenerateNonHierarchyNonSplittingEntityType(ITableBase table, TableExpressionBase tableExpression)
{
var tableReferenceExpression = new TableReferenceExpression(this, tableExpression.Alias!);
AddTable(tableExpression, tableReferenceExpression);

var propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
foreach (var property in GetAllPropertiesInHierarchy(entityType))
{
propertyExpressions[property] = CreateColumnExpression(property, table, tableReferenceExpression, nullable: false);
}

var entityProjection = new EntityProjectionExpression(entityType, propertyExpressions);
_projectionMapping[new ProjectionMember()] = entityProjection;

var primaryKey = entityType.FindPrimaryKey();
if (primaryKey != null)
{
foreach (var property in primaryKey.Properties)
{
_identifier.Add((propertyExpressions[property], property.GetKeyValueComparer()));
}
}
}

static ITableBase GetTableBaseFiltered(IEntityType entityType, List<ITableBase> existingTables)
=> entityType.GetViewOrTableMappings().Single(m => !existingTables.Contains(m.Table)).Table;
Expand Down Expand Up @@ -1713,8 +1772,8 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
_projectionMapping.Clear();
select1._identifier.AddRange(_identifier);
_identifier.Clear();
select1._tptLeftJoinTables.AddRange(_tptLeftJoinTables);
_tptLeftJoinTables.Clear();
select1._removableJoinTables.AddRange(_removableJoinTables);
_removableJoinTables.Clear();
foreach (var kvp in _tpcDiscriminatorValues)
{
select1._tpcDiscriminatorValues[kvp.Key] = kvp.Value;
Expand Down Expand Up @@ -2902,8 +2961,8 @@ private SqlRemappingVisitor PushdownIntoSubqueryInternal()
Having = null;
Offset = null;
Limit = null;
subquery._tptLeftJoinTables.AddRange(_tptLeftJoinTables);
_tptLeftJoinTables.Clear();
subquery._removableJoinTables.AddRange(_removableJoinTables);
_removableJoinTables.Clear();
foreach (var kvp in _tpcDiscriminatorValues)
{
subquery._tpcDiscriminatorValues[kvp.Key] = kvp.Value;
Expand Down Expand Up @@ -3213,14 +3272,17 @@ private SelectExpression Prune(IReadOnlyCollection<string>? referencedColumns)
var columnExpressionFindingExpressionVisitor = new ColumnExpressionFindingExpressionVisitor();
var columnsMap = columnExpressionFindingExpressionVisitor.FindColumns(this);
var removedTableCount = 0;
// Start at 1 because we don't drop main table.
// Dropping main table is more complex because other tables need to unwrap joins to be main
for (var i = 0; i < _tables.Count; i++)
{
var table = _tables[i];
var tableAlias = GetAliasFromTableExpressionBase(table);
if (columnsMap[tableAlias] == null
&& (table is LeftJoinExpression
|| table is OuterApplyExpression)
&& _tptLeftJoinTables?.Contains(i + removedTableCount) == true)
|| table is OuterApplyExpression
|| table is InnerJoinExpression) // This is only valid for removable join table which are from entity splitting
&& _removableJoinTables?.Contains(i + removedTableCount) == true)
{
_tables.RemoveAt(i);
_tableReferences.RemoveAt(i);
Expand Down Expand Up @@ -3341,7 +3403,14 @@ private static ConcreteColumnExpression CreateColumnExpression(
ITableBase table,
TableReferenceExpression tableExpression,
bool nullable)
=> new(property, table.FindColumn(property)!, tableExpression, nullable);
=> CreateColumnExpression(property, table.FindColumn(property)!, tableExpression, nullable);

private static ConcreteColumnExpression CreateColumnExpression(
IProperty property,
IColumnBase columnBase,
TableReferenceExpression tableExpression,
bool nullable)
=> new(property, columnBase, tableExpression, nullable);

private ConcreteColumnExpression GenerateOuterColumn(
TableReferenceExpression tableReferenceExpression,
Expand Down Expand Up @@ -3578,7 +3647,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
_usedAliases = _usedAliases,
};
newSelectExpression._mutable = false;
newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables);
newSelectExpression._removableJoinTables.AddRange(_removableJoinTables);
foreach (var kvp in newTpcDiscriminatorValues)
{
newSelectExpression._tpcDiscriminatorValues[kvp.Key] = kvp.Value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ protected EntitySplittingTestBase(ITestOutputHelper testOutputHelper)
//TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

[ConditionalFact(Skip = "Entity splitting query Issue #620")]
[ConditionalFact]
public virtual async Task Can_roundtrip()
{
await InitializeAsync(OnModelCreating, sensitiveLogEnabled: false);
await InitializeAsync(OnModelCreating, sensitiveLogEnabled: true);

await using (var context = CreateContext())
{
Expand Down
Loading

0 comments on commit 43e0755

Please sign in to comment.