Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new interval based cost function #2972

Merged
merged 3 commits into from
May 17, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions benchmarks/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
<artifactId>druid-processing</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>io.druid</groupId>
<artifactId>druid-server</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>com.github.wnameless</groupId>
<artifactId>json-flattener</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package io.druid.server.coordinator;

import io.druid.timeline.DataSegment;
import org.joda.time.DateTime;
import org.joda.time.Interval;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;

/**
 * JMH benchmark for the static cost functions exposed by {@link CostBalancerStrategy}.
 *
 * The benchmark state consists of one reference segment starting at {@code t0} and {@code n}
 * additional one-hour segments whose start times are drawn from a fixed-seed {@link Random}.
 */
@State(Scope.Benchmark)
public class CostBalancerStrategyBenchmark
{
  /** Start of the reference segment; generated segments are offset from this instant. */
  private static final DateTime t0 = new DateTime("2016-01-01T01:00:00Z");

  /** Segments the reference segment is compared against; filled in by {@link #setupDummyCluster()}. */
  private List<DataSegment> segments;
  /** Reference segment; filled in by {@link #setupDummyCluster()}. */
  private DataSegment segment;

  // Arguments handed to CostBalancerStrategy.intervalCost by measureIntervalPenalty().
  int x1 = 2;
  int y0 = 3;
  int y1 = 4;

  // Number of segments generated during setup.
  int n = 10000;

  /**
   * Creates the reference segment and {@code n} segments with pseudo-random start times.
   *
   * Each start time is {@code t0} minus an hour offset in {@code [-365 * 12, 365 * 12)}, so the
   * generated segments start within roughly half a year before or after {@code t0}. The random
   * seed is fixed, so every run produces the same sequence of offsets.
   */
  @Setup
  public void setupDummyCluster()
  {
    segment = createSegment(t0);

    final Random rng = new Random(1234);
    final List<DataSegment> generated = new ArrayList<>(n);
    while (generated.size() < n) {
      final int hoursBack = rng.nextInt(365 * 24) - 365 * 12;
      generated.add(createSegment(t0.minusHours(hoursBack)));
    }
    segments = generated;
  }

  /**
   * Builds a segment of data source "test", version "v1", covering the hour that starts at
   * {@code t}. All remaining constructor arguments are passed as null or zero.
   *
   * @param t start of the segment interval
   *
   * @return the newly created segment
   */
  DataSegment createSegment(DateTime t)
  {
    final Interval oneHour = new Interval(t, t.plusHours(1));
    return new DataSegment(
        "test",
        oneHour,
        "v1",
        null,
        null,
        null,
        null,
        0,
        0
    );
  }

  /**
   * Measures the average time needed to sum the pairwise joint cost of the reference segment
   * against every generated segment.
   *
   * @return the accumulated cost, returned so the computed value is handed back to the caller
   */
  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.MICROSECONDS)
  @Fork(1)
  public double measureCostStrategySingle() throws InterruptedException
  {
    double sum = 0;
    for (DataSegment other : segments) {
      sum += CostBalancerStrategy.computeJointSegmentsCost(segment, other);
    }
    return sum;
  }

  /**
   * Measures the average time of a single {@code intervalCost} evaluation using the
   * {@code x1}, {@code y0} and {@code y1} fields as arguments.
   *
   * @return the computed interval cost
   */
  @Benchmark
  @BenchmarkMode(Mode.AverageTime)
  @OutputTimeUnit(TimeUnit.MICROSECONDS)
  @Fork(1)
  public double measureIntervalPenalty() throws InterruptedException
  {
    return CostBalancerStrategy.intervalCost(x1, y0, y1);
  }
}
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,11 @@
<artifactId>derbyclient</artifactId>
<version>10.11.1.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>

<!-- Test Scope -->
<dependency>
Expand Down
4 changes: 4 additions & 0 deletions server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@
<groupId>org.apache.derby</groupId>
<artifactId>derbyclient</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>

<!-- Tests -->
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,36 @@

package io.druid.server.coordinator;

import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.metamx.common.Pair;
import com.metamx.emitter.EmittingLogger;
import io.druid.timeline.DataSegment;
import org.joda.time.DateTime;
import org.apache.commons.math3.util.FastMath;
import org.joda.time.Interval;

import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;

public class CostBalancerStrategy implements BalancerStrategy
{
private static final EmittingLogger log = new EmittingLogger(CostBalancerStrategy.class);
private static final long DAY_IN_MILLIS = 1000 * 60 * 60 * 24;
private static final long SEVEN_DAYS_IN_MILLIS = 7 * DAY_IN_MILLIS;
private static final long THIRTY_DAYS_IN_MILLIS = 30 * DAY_IN_MILLIS;
private final long referenceTimestamp;
private final ListeningExecutorService exec;

public static long gapMillis(Interval interval1, Interval interval2)
{
if (interval1.getStartMillis() > interval2.getEndMillis()) {
return interval1.getStartMillis() - interval2.getEndMillis();
} else if (interval2.getStartMillis() > interval1.getEndMillis()) {
return interval2.getStartMillis() - interval1.getEndMillis();
} else {
return 0;
}
}
static final double HALF_LIFE = 24.0; // cost function half-life in hours
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'd be really nice to have some comments about what everything means and how the algo works

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Half-life is by definition ln(2) / lambda, i.e. the time difference that will make the joint cost go down by half.

static final double LAMBDA = Math.log(2) / HALF_LIFE;
static final double INV_LAMBDA_SQUARE = 1 / (LAMBDA * LAMBDA);

public CostBalancerStrategy(DateTime referenceTimestamp, ListeningExecutorService exec)
private static final double MILLIS_IN_HOUR = 3_600_000.0;
private static final double MILLIS_FACTOR = MILLIS_IN_HOUR / LAMBDA;

private final ListeningExecutorService exec;

public CostBalancerStrategy(ListeningExecutorService exec)
{
this.referenceTimestamp = referenceTimestamp.getMillis();
this.exec = exec;
}

Expand All @@ -81,50 +73,107 @@ public ServerHolder findNewSegmentHomeBalancer(
return chooseBestServer(proposalSegment, serverHolders, true).rhs;
}

static double computeJointSegmentsCost(final DataSegment segment, final Iterable<DataSegment> segmentSet)
{
double totalCost = 0;
for(DataSegment s : segmentSet) {
totalCost += computeJointSegmentsCost(segment, s);
}
return totalCost;
}

/**
* This defines the unnormalized cost function between two segments. There is a base cost given by
* the minimum size of the two segments and additional penalties.
* recencyPenalty: it is more likely that recent segments will be queried together
* This defines the unnormalized cost function between two segments.
*
* dataSourcePenalty: if two segments belong to the same data source, they are more likely to be involved
* in the same queries
* gapPenalty: it is more likely that segments close together in time will be queried together
*
* @param segment1 The first DataSegment.
* @param segment2 The second DataSegment.
* intervalPenalty: it is more likely that segments close together in time will be queried together
*
* @param segmentA The first DataSegment.
* @param segmentB The second DataSegment.
*
* @return The joint cost of placing the two DataSegments together on one node.
*/
public double computeJointSegmentCosts(final DataSegment segment1, final DataSegment segment2)
public static double computeJointSegmentsCost(final DataSegment segmentA, final DataSegment segmentB)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move static fns to top of classes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can probably also delete static method gapMillis now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already removed it.

final long gapMillis = gapMillis(segment1.getInterval(), segment2.getInterval());
final Interval intervalA = segmentA.getInterval();
final Interval intervalB = segmentB.getInterval();

final double baseCost = Math.min(segment1.getSize(), segment2.getSize());
double recencyPenalty = 1;
double dataSourcePenalty = 1;
double gapPenalty = 1;
final double t0 = intervalA.getStartMillis();
final double t1 = (intervalA.getEndMillis() - t0) / MILLIS_FACTOR;
final double start = (intervalB.getStartMillis() - t0) / MILLIS_FACTOR;
final double end = (intervalB.getEndMillis() - t0) / MILLIS_FACTOR;

if (segment1.getDataSource().equals(segment2.getDataSource())) {
dataSourcePenalty = 2;
}
final double multiplier = segmentA.getDataSource().equals(segmentB.getDataSource()) ? 2.0 : 1.0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is going on here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an extra penalty to co-locating segments from the same source. The idea is that a given query hits a single datasource, so this encourages segments from the same source to be more spread out than segments from different sources (which are unlikely to be queried at the same time).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's the datasource multiplier, to ensure we avoid co-locating segments of the same datasource. I updated the PR description to add a note about this.


double segment1diff = referenceTimestamp - segment1.getInterval().getEndMillis();
double segment2diff = referenceTimestamp - segment2.getInterval().getEndMillis();
if (segment1diff < SEVEN_DAYS_IN_MILLIS && segment2diff < SEVEN_DAYS_IN_MILLIS) {
recencyPenalty = (2 - segment1diff / SEVEN_DAYS_IN_MILLIS) * (2 - segment2diff / SEVEN_DAYS_IN_MILLIS);
return INV_LAMBDA_SQUARE * intervalCost(t1, start, end) * multiplier;
}

/**
* Computes the joint cost of two intervals X = [x_0 = 0, x_1) and Y = [y_0, y_1)
*
* cost(X, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|} dx dy
*
* lambda = 1 in this particular implementation
*
* Other values of lambda can be calculated by multiplying inputs by lambda
* and multiplying the result by 1 / lambda ^ 2
*
* Interval start and end are all relative to x_0.
* Therefore this function assumes x_0 = 0, x1 >= 0, and y1 > y0
*
* @param x1 end of interval X
* @param y0 start of interval Y
* @param y1 end of interval Y
* @return joint cost of X and Y
*/
public static double intervalCost(double x1, double y0, double y1)
{
if (x1 == 0 || y1 == y0) {
return 0;
}

/** gap is zero if the two segment intervals overlap or if they're adjacent */
if (gapMillis == 0) {
gapPenalty = 2;
} else {
if (gapMillis < THIRTY_DAYS_IN_MILLIS) {
gapPenalty = 2 - gapMillis / THIRTY_DAYS_IN_MILLIS;
}
if(y0 < 0) {
// swap X and Y
double tmp = x1;
x1 = y1 - y0;
y1 = tmp - y0;
y0 = -y0;
}

final double cost = baseCost * recencyPenalty * dataSourcePenalty * gapPenalty;
// Y overlaps X
if (y0 < x1) {
/**
* X [ A )[ B )[ C ) or [ A )[ B )
* Y [ ) [ )[ C )
*
* A could be empty if y0 == 0
* C could be empty if y1 == x1
*
* cost(X, Y) = cost(A, Y) + cost(B, C) + cost(B, B)
*/
final double beta; // b1 - y0
final double gamma; // c1 - y0
if(y1 <= x1) {
beta = y1 - y0;
gamma = x1 - y0;
} else {
beta = x1 - y0;
gamma = y1 - y0;
}
return intervalCost(y0, y0, y1) +
intervalCost(beta, beta, gamma) +
// cost of exactly overlapping intervals of size beta
2 * (beta + FastMath.exp(-beta) - 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a constant?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or I suppose this is just the solution to the integral?

Can the comment be described as such?

Copy link
Member Author

@xvrl xvrl May 17, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@drcrallen this can't be a constant, since beta depends on interval start / end. And yes, this is the solution to
\int_0^{\beta} \int_{0}^{\beta} e^{-|x-y|}dxdy = 2 \cdot (\beta + e^{-\beta} - 1)

screen shot 2016-05-16 at 9 14 26 pm

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add some comments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the constant out front, before I realized it's just the solution to the equation.

} else {
final double exy0 = FastMath.exp(x1 - y0);
final double exy1 = FastMath.exp(x1 - y1);
final double ey0 = FastMath.exp(0f - y0);
final double ey1 = FastMath.exp(0f - y1);

return cost;
return (ey1 - ey0) - (exy1 - exy0);
}
}

public BalancerSegmentHolder pickSegmentToMove(final List<ServerHolder> serverHolders)
Expand All @@ -144,11 +193,9 @@ public double calculateInitialTotalCost(final List<ServerHolder> serverHolders)
{
double cost = 0;
for (ServerHolder server : serverHolders) {
DataSegment[] segments = server.getServer().getSegments().values().toArray(new DataSegment[]{});
for (int i = 0; i < segments.length; ++i) {
for (int j = i; j < segments.length; ++j) {
cost += computeJointSegmentCosts(segments[i], segments[j]);
}
Iterable<DataSegment> segments = server.getServer().getSegments().values();
for (DataSegment s : segments) {
cost += computeJointSegmentsCost(s, segments);
}
}
return cost;
Expand All @@ -168,8 +215,8 @@ public double calculateNormalization(final List<ServerHolder> serverHolders)
{
double cost = 0;
for (ServerHolder server : serverHolders) {
for (DataSegment segment : server.getServer().getSegments().values()) {
cost += computeJointSegmentCosts(segment, segment);
for (DataSegment segment : server.getServer().getSegments().values()) {
cost += computeJointSegmentsCost(segment, segment);
}
}
return cost;
Expand Down Expand Up @@ -211,17 +258,20 @@ protected double computeCost(
}

/** The contribution to the total cost of a given server by proposing to move the segment to that server is... */
double cost = 0f;
double cost = 0d;

/** the sum of the costs of other (exclusive of the proposalSegment) segments on the server */
for (DataSegment segment : server.getServer().getSegments().values()) {
if (!proposalSegment.equals(segment)) {
cost += computeJointSegmentCosts(proposalSegment, segment);
}
}
cost += computeJointSegmentsCost(
proposalSegment,
Iterables.filter(
server.getServer().getSegments().values(),
Predicates.not(Predicates.equalTo(proposalSegment))
)
);

/** plus the costs of segments that will be loaded */
for (DataSegment segment : server.getPeon().getSegmentsToLoad()) {
cost += computeJointSegmentCosts(proposalSegment, segment);
}
cost += computeJointSegmentsCost(proposalSegment, server.getPeon().getSegmentsToLoad());

return cost;
}
return Double.POSITIVE_INFINITY;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public CostBalancerStrategyFactory(int costBalancerStrategyThreadCount)
}

@Override
public BalancerStrategy createBalancerStrategy(DateTime referenceTimestamp)
public CostBalancerStrategy createBalancerStrategy(DateTime referenceTimestamp)
{
return new CostBalancerStrategy(referenceTimestamp, exec);
return new CostBalancerStrategy(exec);
}

@Override
Expand Down
Loading