diff --git a/Sources/SQLite/Typed/Query.swift b/Sources/SQLite/Typed/Query.swift index 04665f39..feff2f69 100644 --- a/Sources/SQLite/Typed/Query.swift +++ b/Sources/SQLite/Typed/Query.swift @@ -1036,10 +1036,29 @@ extension Connection { let column = names.removeLast() let namespace = names.joined(separator: ".") + // Return a copy of the input "with" clause stripping all subclauses besides "select", "join", and "with". + func strip(_ with: WithClauses) -> WithClauses { + var stripped = WithClauses() + stripped.recursive = with.recursive + for subclause in with.clauses { + let query = subclause.query + var strippedQuery = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database) + strippedQuery.clauses.select = query.clauses.select + strippedQuery.clauses.join = query.clauses.join + strippedQuery.clauses.with = strip(query.clauses.with) + + var strippedSubclause = WithClauses.Clause(alias: subclause.alias, query: strippedQuery) + strippedSubclause.columns = subclause.columns + stripped.clauses.append(strippedSubclause) + } + return stripped + } + func expandGlob(_ namespace: Bool) -> (QueryType) throws -> Void { { (queryType: QueryType) throws -> Void in var query = type(of: queryType).init(queryType.clauses.from.name, database: queryType.clauses.from.database) query.clauses.select = queryType.clauses.select + query.clauses.with = strip(queryType.clauses.with) let expression = query.expression var names = try self.prepare(expression.template, expression.bindings).columnNames.map { $0.quote() } if namespace { names = names.map { "\(queryType.tableName().expression.template).\($0)" } } diff --git a/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift b/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift index 3fd388e9..f5105a09 100644 --- a/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift +++ b/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift @@ -276,6 +276,39 @@ class QueryIntegrationTests: SQLiteTestCase { XCTAssertEqual(21, sum) } + + /// Verify that `*` is properly expanded in a SELECT statement following a WITH clause. + func test_with_glob_expansion() throws { + let names = Table("names") + let name = Expression("name") + try db.run(names.create { builder in + builder.column(email) + builder.column(name) + }) + + try db.run(users.insert(email <- "alice@example.com")) + try db.run(names.insert(email <- "alice@example.com", name <- "Alice")) + + // WITH intermediate AS ( SELECT ... ) SELECT * FROM intermediate + let intermediate = Table("intermediate") + let rows = try db.prepare( + intermediate + .with(intermediate, + as: users + .select([id, users[email], name]) + .join(names, on: names[email] == users[email]) + .where(users[email] == "alice@example.com") + )) + + // There should be at least one row in the result. + let row = try XCTUnwrap(rows.makeIterator().next()) + + // Verify the column names + XCTAssertEqual(row.columnNames.count, 3) + XCTAssertNotNil(row[id]) + XCTAssertNotNil(row[name]) + XCTAssertNotNil(row[email]) + } } extension Connection {