Skip to content

Commit

Permalink
[Backport] Dart and security backports (#17249) (#17278) (#17281) (#1…
Browse files Browse the repository at this point in the history
…7282) (#17283) (#17277) (#17285)

* MSQ: Allow for worker gaps. (#17277)
* DartSqlResource: Sort queries by start time. (#17282)
* DartSqlResource: Add controllerHost to GetQueriesResponse. (#17283)
* DartWorkerModule: Replace en dash with regular dash. (#17281)
* DartSqlResource: Return HTTP 202 on cancellation even if no such query. (#17278)
* Upgraded Protobuf to 3.25.5 (#17249)
---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
(cherry picked from commit 7d9e6d3)
---------
Co-authored-by: Gian Merlino <gianmerlino@gmail.com>
Co-authored-by: Shivam Garg <shigarg@visa.com>
  • Loading branch information
cryptoe authored Oct 8, 2024
1 parent f43964a commit ccb7c2e
Show file tree
Hide file tree
Showing 20 changed files with 499 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public enum State
private final ControllerContext controllerContext;
private final String sqlQueryId;
private final String sql;
private final String controllerHost;
private final AuthenticationResult authenticationResult;
private final DateTime startTime;
private final AtomicReference<State> state = new AtomicReference<>(State.ACCEPTED);
Expand All @@ -68,6 +69,7 @@ public ControllerHolder(
final ControllerContext controllerContext,
final String sqlQueryId,
final String sql,
final String controllerHost,
final AuthenticationResult authenticationResult,
final DateTime startTime
)
Expand All @@ -76,6 +78,7 @@ public ControllerHolder(
this.controllerContext = controllerContext;
this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId");
this.sql = sql;
this.controllerHost = controllerHost;
this.authenticationResult = authenticationResult;
this.startTime = Preconditions.checkNotNull(startTime, "startTime");
}
Expand All @@ -95,6 +98,11 @@ public String getSql()
return sql;
}

public String getControllerHost()
{
return controllerHost;
}

public AuthenticationResult getAuthenticationResult()
{
return authenticationResult;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.druid.msq.dart.controller.ControllerHolder;
import org.apache.druid.msq.util.MSQTaskQueryMakerUtils;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.server.DruidNode;
import org.joda.time.DateTime;

import java.util.Objects;
Expand All @@ -38,6 +39,7 @@ public class DartQueryInfo
private final String sqlQueryId;
private final String dartQueryId;
private final String sql;
private final String controllerHost;
private final String authenticator;
private final String identity;
private final DateTime startTime;
Expand All @@ -48,6 +50,7 @@ public DartQueryInfo(
@JsonProperty("sqlQueryId") final String sqlQueryId,
@JsonProperty("dartQueryId") final String dartQueryId,
@JsonProperty("sql") final String sql,
@JsonProperty("controllerHost") final String controllerHost,
@JsonProperty("authenticator") final String authenticator,
@JsonProperty("identity") final String identity,
@JsonProperty("startTime") final DateTime startTime,
Expand All @@ -57,6 +60,7 @@ public DartQueryInfo(
this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId");
this.dartQueryId = Preconditions.checkNotNull(dartQueryId, "dartQueryId");
this.sql = sql;
this.controllerHost = controllerHost;
this.authenticator = authenticator;
this.identity = identity;
this.startTime = startTime;
Expand All @@ -69,6 +73,7 @@ public static DartQueryInfo fromControllerHolder(final ControllerHolder holder)
holder.getSqlQueryId(),
holder.getController().queryId(),
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(holder.getSql()),
holder.getControllerHost(),
holder.getAuthenticationResult().getAuthenticatedBy(),
holder.getAuthenticationResult().getIdentity(),
holder.getStartTime(),
Expand Down Expand Up @@ -104,6 +109,16 @@ public String getSql()
return sql;
}

/**
* Controller host:port, from {@link DruidNode#getHostAndPortToUse()}, that is executing this query.
*/
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public String getControllerHost()
{
return controllerHost;
}

/**
* Authenticator that authenticated the identity from {@link #getIdentity()}.
*/
Expand Down Expand Up @@ -145,7 +160,7 @@ public String getState()
*/
public DartQueryInfo withoutAuthenticationResult()
{
return new DartQueryInfo(sqlQueryId, dartQueryId, sql, null, null, startTime, state);
return new DartQueryInfo(sqlQueryId, dartQueryId, sql, controllerHost, null, null, startTime, state);
}

@Override
Expand All @@ -161,6 +176,7 @@ public boolean equals(Object o)
return Objects.equals(sqlQueryId, that.sqlQueryId)
&& Objects.equals(dartQueryId, that.dartQueryId)
&& Objects.equals(sql, that.sql)
&& Objects.equals(controllerHost, that.controllerHost)
&& Objects.equals(authenticator, that.authenticator)
&& Objects.equals(identity, that.identity)
&& Objects.equals(startTime, that.startTime)
Expand All @@ -170,7 +186,7 @@ public boolean equals(Object o)
@Override
public int hashCode()
{
return Objects.hash(sqlQueryId, dartQueryId, sql, authenticator, identity, startTime, state);
return Objects.hash(sqlQueryId, dartQueryId, sql, controllerHost, authenticator, identity, startTime, state);
}

@Override
Expand All @@ -180,10 +196,11 @@ public String toString()
"sqlQueryId='" + sqlQueryId + '\'' +
", dartQueryId='" + dartQueryId + '\'' +
", sql='" + sql + '\'' +
", controllerHost='" + controllerHost + '\'' +
", authenticator='" + authenticator + '\'' +
", identity='" + identity + '\'' +
", startTime=" + startTime +
", state=" + state +
", state='" + state + '\'' +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ public GetQueriesResponse doGetRunningQueries(
controllerRegistry.getAllHolders()
.stream()
.map(DartQueryInfo::fromControllerHolder)
.sorted(Comparator.comparing(DartQueryInfo::getStartTime))
.collect(Collectors.toList());

// Add queries from all other servers, if "selfOnly" is not set.
Expand All @@ -172,6 +171,9 @@ public GetQueriesResponse doGetRunningQueries(
}
}

// Sort queries by start time, breaking ties by query ID, so the list comes back in a consistent and nice order.
queries.sort(Comparator.comparing(DartQueryInfo::getStartTime).thenComparing(DartQueryInfo::getDartQueryId));

final GetQueriesResponse response;
if (stateReadAccess.isAllowed()) {
// User can READ STATE, so they can see all running queries, as well as authentication details.
Expand Down Expand Up @@ -237,7 +239,10 @@ public Response cancelQuery(

List<SqlLifecycleManager.Cancelable> cancelables = sqlLifecycleManager.getAll(sqlQueryId);
if (cancelables.isEmpty()) {
return Response.status(Response.Status.NOT_FOUND).build();
// Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all
// Brokers, this ensures the user sees a successful request.
AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req);
return Response.status(Response.Status.ACCEPTED).build();
}

final Access access = authorizeCancellation(req, cancelables);
Expand All @@ -247,14 +252,12 @@ public Response cancelQuery(

// Don't call cancel() on the cancelables. That just cancels native queries, which is useless here. Instead,
// get the controller and stop it.
boolean found = false;
for (SqlLifecycleManager.Cancelable cancelable : cancelables) {
final HttpStatement stmt = (HttpStatement) cancelable;
final Object dartQueryId = stmt.context().get(DartSqlEngine.CTX_DART_QUERY_ID);
if (dartQueryId instanceof String) {
final ControllerHolder holder = controllerRegistry.get((String) dartQueryId);
if (holder != null) {
found = true;
holder.cancel();
}
} else {
Expand All @@ -267,7 +270,9 @@ public Response cancelQuery(
}
}

return Response.status(found ? Response.Status.ACCEPTED : Response.Status.NOT_FOUND).build();
// Return ACCEPTED even if the query wasn't found. When the Router broadcasts cancellation requests to all
// Brokers, this ensures the user sees a successful request.
return Response.status(Response.Status.ACCEPTED).build();
} else {
return Response.status(Response.Status.FORBIDDEN).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ public QueryResponse<Object[]> runQuery(DruidQuery druidQuery)
controllerContext,
plannerContext.getSqlQueryId(),
plannerContext.getSql(),
controllerContext.selfNode().getHostAndPortToUse(),
plannerContext.getAuthenticationResult(),
DateTimes.nowUtc()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public DartWorkerRunner createWorkerRunner(
final AuthorizerMapper authorizerMapper
)
{
final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dartworker-%s");
final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart-worker-%s");
final File baseTempDir =
new File(processingConfig.getTmpDir(), StringUtils.format("dart_%s", selfNode.getPortToUse()));
return new DartWorkerRunner(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ public static ReadablePartition striped(final int stageNumber, final int numWork
return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
}

/**
* Returns an output partition that is striped across a set of {@code workerNumbers}.
*/
public static ReadablePartition striped(
final int stageNumber,
final IntSortedSet workerNumbers,
final int partitionNumber
)
{
return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
}

/**
* Returns an output partition that has been collected onto a single worker.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2IntSortedMap;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;

import java.util.Collections;
import java.util.List;
Expand All @@ -39,6 +40,7 @@
@JsonSubTypes(value = {
@JsonSubTypes.Type(name = "collected", value = CollectedReadablePartitions.class),
@JsonSubTypes.Type(name = "striped", value = StripedReadablePartitions.class),
@JsonSubTypes.Type(name = "sparseStriped", value = SparseStripedReadablePartitions.class),
@JsonSubTypes.Type(name = "combined", value = CombinedReadablePartitions.class)
})
public interface ReadablePartitions extends Iterable<ReadablePartition>
Expand All @@ -59,7 +61,7 @@ static ReadablePartitions empty()
/**
* Combines various sets of partitions into a single set.
*/
static CombinedReadablePartitions combine(List<ReadablePartitions> readablePartitions)
static ReadablePartitions combine(List<ReadablePartitions> readablePartitions)
{
return new CombinedReadablePartitions(readablePartitions);
}
Expand All @@ -68,7 +70,7 @@ static CombinedReadablePartitions combine(List<ReadablePartitions> readableParti
* Returns a set of {@code numPartitions} partitions striped across {@code numWorkers} workers: each worker contains
* a "stripe" of each partition.
*/
static StripedReadablePartitions striped(
static ReadablePartitions striped(
final int stageNumber,
final int numWorkers,
final int numPartitions
Expand All @@ -82,11 +84,36 @@ static StripedReadablePartitions striped(
return new StripedReadablePartitions(stageNumber, numWorkers, partitionNumbers);
}

/**
* Returns a set of {@code numPartitions} partitions striped across {@code workers}: each worker contains
* a "stripe" of each partition.
*/
static ReadablePartitions striped(
final int stageNumber,
final IntSortedSet workers,
final int numPartitions
)
{
final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet();
for (int i = 0; i < numPartitions; i++) {
partitionNumbers.add(i);
}

if (workers.lastInt() == workers.size() - 1) {
// Dense worker set. Use StripedReadablePartitions for compactness (send a single number rather than the
// entire worker set) and for backwards compatibility (older workers cannot understand
// SparseStripedReadablePartitions).
return new StripedReadablePartitions(stageNumber, workers.size(), partitionNumbers);
} else {
return new SparseStripedReadablePartitions(stageNumber, workers, partitionNumbers);
}
}

/**
* Returns a set of partitions that have been collected onto specific workers: each partition is on exactly
* one worker.
*/
static CollectedReadablePartitions collected(
static ReadablePartitions collected(
final int stageNumber,
final Map<Integer, Integer> partitionToWorkerMap
)
Expand Down
Loading

0 comments on commit ccb7c2e

Please sign in to comment.