
Various request time metrics (#121)
* first commit

* nit

* fmt

* description tweak

* added more metrics

* nit

* nit

* default metadata values

* move `new_request.metadata.transfer_start_time = time.perf_counter()`

* avoid NoneType

* NoneType

* set transfer_end_time and fmt

* camel case -> snake case

* description update

* change descriptions

* fmt

* logs

* better logs

* changed timings

* observing queue duration metric

* buckets in sorted order

* buckets not in sorted order

* corrected times

* number of output tokens

* move prefill_start_time, enable debug, maybe correct len for num tokens in detokenize

* fmt

* correct lengths of output tokens based on debug

* debug transfer queue time

* remove log

* removed logs, almost final

* nits

* readd log

* change logs

* remove log

* condense

* improve test coverage

* revert _abort_or_raise deletion

* start_time mandatory

* undo

* nit

* updated buckets

* added 'jetstream_time_per_request'

* nit

* add 'jetstream_wait_time_per_request'

* nit

* missing .metadata

* lint

* change order of params

* changed metric description

* Add metadata field to proto

* update proto

* tweak generated file

* tweak generated file

* update proto

* pylint

* generate protos

* change start time assignment

* .value

* CopyFrom

* change definition of queue duration metric

* Increase test coverage

* fixed assertions

* fmt

* incorrect prefill time

* Add license statements

* Protobuf Python Version

* fmt

* pylint
Bslabe123 authored Aug 7, 2024
1 parent 3946afa commit d681995
Showing 7 changed files with 300 additions and 52 deletions.
138 changes: 137 additions & 1 deletion jetstream/core/metrics/prometheus.py
@@ -17,7 +17,6 @@
import os
import shortuuid
from prometheus_client import Counter, Gauge, Histogram

from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS


@@ -37,21 +36,46 @@ def __new__(cls):
documentation="Size of prefill queue",
labelnames=["id"],
)

_transfer_backlog = Gauge(
name="jetstream_transfer_backlog_size",
documentation="Size of transfer queue",
labelnames=["id", "idx"],
)

_generate_backlog = Gauge(
name="jetstream_generate_backlog_size",
documentation="Size of generate queue",
labelnames=["id", "idx"],
)

_queue_duration = Histogram(
name="jetstream_queue_duration",
documentation="The total time each request spends enqueued in seconds",
labelnames=["id"],
buckets=[
0.01,
0.02,
0.05,
0.1,
0.2,
0.5,
1.0,
2.0,
5.0,
10.0,
20.0,
50.0,
100.0,
],
)

_slots_used_percentage = Gauge(
name="jetstream_slots_used_percentage",
documentation="The percentage of decode slots currently being used",
labelnames=["id", "idx"],
)

_server_startup_latency = Gauge(
name="jetstream_server_startup_latency",
documentation="Total time taken to start the Jetstream server",
@@ -96,6 +120,100 @@ def __new__(cls):
labelnames=["id"],
)

_time_to_first_token = Histogram(
name="jetstream_time_to_first_token",
documentation="Time to first token per request in seconds",
labelnames=["id"],
buckets=[
0.001,
0.005,
0.01,
0.02,
0.04,
0.06,
0.08,
0.1,
0.25,
0.5,
0.75,
1.0,
2.5,
5.0,
7.5,
10.0,
],
)

_time_per_output_token = Histogram(
name="jetstream_time_per_output_token",
documentation="Average time per output token per request in seconds",
labelnames=["id"],
buckets=[
0.01,
0.025,
0.05,
0.075,
0.1,
0.15,
0.2,
0.3,
0.4,
0.5,
0.75,
1.0,
2.5,
],
)

_time_per_prefill_token = Histogram(
name="jetstream_time_per_prefill_token",
documentation="Prefill time per token per request in seconds",
labelnames=["id"],
buckets=[
0.00001,
0.00002,
0.00005,
0.0001,
0.0002,
0.0005,
0.001,
0.002,
0.005,
0.01,
0.02,
0.05,
0.1,
],
)

_time_per_request = Histogram(
name="jetstream_time_per_request",
documentation="End to end request latency in seconds",
labelnames=["id"],
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
)

_wait_time_per_request = Histogram(
name="jetstream_wait_time_per_request",
documentation="Time each request is not being prefilled or decoded",
labelnames=["id"],
buckets=[
0.01,
0.02,
0.05,
0.1,
0.2,
0.5,
1.0,
2.0,
5.0,
10.0,
20.0,
50.0,
100.0,
],
)

def get_prefill_backlog_metric(self):
return self._prefill_backlog.labels(id=self._id)

@@ -105,12 +223,30 @@ def get_transfer_backlog_metric(self, idx: int):
def get_generate_backlog_metric(self, idx: int):
return self._generate_backlog.labels(id=self._id, idx=idx)

def get_queue_duration(self):
return self._queue_duration.labels(id=self._id)

def get_slots_used_percentage_metric(self, idx: int):
return self._slots_used_percentage.labels(id=self._id, idx=idx)

def get_server_startup_latency_metric(self):
return self._server_startup_latency.labels(id=self._id)

def get_time_to_first_token(self):
return self._time_to_first_token.labels(id=self._id)

def get_time_per_output_token(self):
return self._time_per_output_token.labels(id=self._id)

def get_time_per_prefill_token(self):
return self._time_per_prefill_token.labels(id=self._id)

def get_time_per_request(self):
return self._time_per_request.labels(id=self._id)

def get_wait_time_per_request(self):
return self._wait_time_per_request.labels(id=self._id)

def get_request_input_length(self):
return self._request_input_length.labels(id=self._id)

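For context on how these collectors are used, each of the new latency metrics above follows the usual prometheus_client Histogram pattern: declare the histogram once with explicit buckets, then record observations through a labeled child. A minimal, self-contained sketch of that pattern is shown below; the metric name, label value, and timing code are illustrative only and are not part of this commit.

import time

from prometheus_client import Histogram

# Illustrative metric; not one of the collectors defined in prometheus.py above.
example_latency = Histogram(
    name="example_request_latency",
    documentation="End to end request latency in seconds",
    labelnames=["id"],
    buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0],
)

start = time.perf_counter()
# ... handle a request ...
example_latency.labels(id="server-0").observe(time.perf_counter() - start)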
99 changes: 94 additions & 5 deletions jetstream/core/orchestrator.py
@@ -109,6 +109,24 @@
root.addHandler(handler)


@dataclasses.dataclass
class ActiveRequestMetadata:
"""Inference request metadata."""

start_time: Optional[float] = None

prefill_enqueue_time: Optional[float] = None
prefill_dequeue_time: Optional[float] = None

transfer_enqueue_time: Optional[float] = None
transfer_dequeue_time: Optional[float] = None

generate_enqueue_time: Optional[float] = None
generate_dequeue_time: Optional[float] = None

complete_time: Optional[float] = None


@dataclasses.dataclass
class ActiveRequest:
"""Current state of the driver."""
@@ -130,6 +148,8 @@ class ActiveRequest:
# Which generate step this was added at.
generate_timestep_added: Optional[int] = None
is_client_side_tokenization: Optional[bool] = False
################## Information relevant for metrics ###################
metadata: ActiveRequestMetadata = ActiveRequestMetadata()

def enqueue_samples(self, generated_samples: list[ReturnSample]):
"""Adds the generated sample(s) to return channel for current step.
@@ -477,10 +497,10 @@ def _prefill_thread(self, idx: int):
my_transfer_backlog = self._transfer_backlogs[idx]
# The prefill thread can just sleep until it has work to do.
request = self._prefill_backlog.get(block=True)
request_start_time = time.perf_counter()

if request is None:
break
request.metadata.prefill_dequeue_time = time.perf_counter()
is_bos = True
logging.info(
"Prefilling on prefill engine %d : prefill queue size, %d,"
@@ -511,8 +531,10 @@ def _prefill_thread(self, idx: int):
# put first token to detokenize queue
request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
my_detokenize_backlog = self._detokenize_backlogs[idx]
request.metadata.transfer_enqueue_time = time.perf_counter()
my_detokenize_backlog.put(
(first_token, request, request_start_time), block=True
(first_token, request, request.metadata.prefill_dequeue_time),
block=True,
)

# Once prefill is complete, place it on the generation queue and block if
@@ -526,6 +548,15 @@
if self._metrics_collector:
self._metrics_collector.get_request_input_length().observe(true_length)

if self._metrics_collector:
self._metrics_collector.get_time_per_prefill_token().observe(
(
request.metadata.transfer_enqueue_time
- request.metadata.prefill_dequeue_time
)
/ true_length
)

del prefill_result
del request

@@ -562,6 +593,7 @@ def _transfer_thread(self, idx: int):
new_request = transfer_backlog.get(block=True)
if new_request is None:
break
new_request.metadata.transfer_dequeue_time = time.perf_counter()
target_idx = min(
self._generate_backlogs.items(), key=lambda q: q[1].qsize()
)[0]
@@ -577,6 +609,7 @@
# Transfer the info to the relevant generate slice.
self._transfer_prefill_result(new_request, target_idx)
# Place the request on the correct generate backlog and block if full.
new_request.metadata.generate_enqueue_time = time.perf_counter()
self._generate_backlogs[target_idx].put(new_request, block=True)
logging.info(
"Successfully transferred prefill "
@@ -649,6 +682,24 @@ def _generate_thread(self, idx: int):
block |= not self._transfer_backlogs[idx].empty()
try:
new_request = my_generate_backlog.get(block=block, timeout=1.0)
if new_request is None:
break
new_request.metadata.generate_dequeue_time = time.perf_counter()
if (
self._metrics_collector
and new_request.metadata.start_time is not None
):
self._metrics_collector.get_queue_duration().observe(
# Time in prefill queue
new_request.metadata.prefill_dequeue_time
- new_request.metadata.prefill_enqueue_time
# Time in transfer queue
+ new_request.metadata.transfer_dequeue_time
- new_request.metadata.transfer_enqueue_time
# Time in generate queue
+ new_request.metadata.generate_dequeue_time
- new_request.metadata.generate_enqueue_time
)
# Got free slot and new request, use them.
except queue.Empty:
# No new requests, we can't insert, so put back slot.
@@ -731,7 +782,7 @@ def _detokenize_thread(self, idx: int):
start_detokenize_time = time.time()
# prefill first token
if isinstance(data[0], engine_api.ResultTokens):
request_first_token, request, request_start_time = data
request_first_token, request, _ = data
request_first_token = request_first_token.convert_to_numpy()

results, complete = token_utils.process_result_tokens(
@@ -747,9 +798,14 @@
request.enqueue_samples(results)

first_token_return_time = time.perf_counter()
if self._metrics_collector:
self._metrics_collector.get_time_to_first_token().observe(
first_token_return_time - request.metadata.prefill_dequeue_time
)
logging.info(
"TTFT duration: %fms",
(first_token_return_time - request_start_time) * 1000,
(first_token_return_time - request.metadata.prefill_dequeue_time)
* 1000,
)
# generate step tokens
elif isinstance(data[1], engine_api.ResultTokens):
@@ -773,12 +829,41 @@
# Return some output samples.
request.enqueue_samples(results)
if request.complete.all():
request.metadata.complete_time = time.perf_counter()
request.return_channel.close()
if self._metrics_collector:
self._metrics_collector.get_request_output_length().observe(
result_tokens.get_result_at_slot(slot).lengths
)
self._metrics_collector.get_request_success_count_metric().inc()
request.return_channel.close()
self._metrics_collector.get_time_per_output_token().observe(
(
request.metadata.complete_time
- request.metadata.transfer_enqueue_time
)
/ result_tokens.get_result_at_slot(slot).lengths
)
self._metrics_collector.get_time_per_request().observe(
request.metadata.complete_time
- request.metadata.transfer_enqueue_time
)

if request.metadata.start_time:
total_time = (
request.metadata.complete_time
- request.metadata.start_time
)
prefill_time = (
request.metadata.transfer_enqueue_time
- request.metadata.prefill_dequeue_time
)
generate_time = (
request.metadata.complete_time
- request.metadata.generate_dequeue_time
)
self._metrics_collector.get_wait_time_per_request().observe(
total_time - prefill_time - generate_time
)
# Place the slot back on the free queue.
my_live_requests[slot] = None
my_slots.put(slot, block=False) # This should always have space.
@@ -895,6 +980,10 @@ async def Decode( # pylint: disable=invalid-overridden-method
prefill_content=prefill_content,
is_client_side_tokenization=is_client_side_tokenization,
return_channel=return_channel,
metadata=ActiveRequestMetadata(
start_time=request.metadata.start_time,
prefill_enqueue_time=time.perf_counter(),
),
)
# The first stage is being prefilled, all other stages are handled
# inside the driver (transfer, generate*N, detokenize).
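To make the timing arithmetic in this file easier to follow, here is a small, self-contained sketch of how the new metrics are derived from the ActiveRequestMetadata timestamps. The helper functions are illustrative only and assume all timestamps have been set; in the actual change the observations are made inline in _prefill_thread, _generate_thread, and _detokenize_thread as shown in the diff above.

import dataclasses
from typing import Optional


@dataclasses.dataclass
class ActiveRequestMetadata:
  """Mirrors the dataclass added to orchestrator.py."""

  start_time: Optional[float] = None
  prefill_enqueue_time: Optional[float] = None
  prefill_dequeue_time: Optional[float] = None
  transfer_enqueue_time: Optional[float] = None
  transfer_dequeue_time: Optional[float] = None
  generate_enqueue_time: Optional[float] = None
  generate_dequeue_time: Optional[float] = None
  complete_time: Optional[float] = None


def queue_duration(m: ActiveRequestMetadata) -> float:
  # Total time spent waiting in the prefill, transfer, and generate queues.
  return (
      (m.prefill_dequeue_time - m.prefill_enqueue_time)
      + (m.transfer_dequeue_time - m.transfer_enqueue_time)
      + (m.generate_dequeue_time - m.generate_enqueue_time)
  )


def time_per_prefill_token(m: ActiveRequestMetadata, true_length: int) -> float:
  # Prefill runs between prefill dequeue and transfer enqueue.
  return (m.transfer_enqueue_time - m.prefill_dequeue_time) / true_length


def time_per_output_token(m: ActiveRequestMetadata, output_length: int) -> float:
  # Decode is measured from transfer enqueue to request completion.
  return (m.complete_time - m.transfer_enqueue_time) / output_length


def time_per_request(m: ActiveRequestMetadata) -> float:
  # End-to-end latency as observed by the detokenize thread.
  return m.complete_time - m.transfer_enqueue_time


def wait_time(m: ActiveRequestMetadata) -> float:
  # Time the request is neither being prefilled nor decoded.
  total = m.complete_time - m.start_time
  prefill = m.transfer_enqueue_time - m.prefill_dequeue_time
  generate = m.complete_time - m.generate_dequeue_time
  return total - prefill - generate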
11 changes: 10 additions & 1 deletion jetstream/core/proto/jetstream.proto
@@ -50,8 +50,17 @@ message DecodeRequest {
TextContent text_content = 5;
TokenContent token_content = 6;
}

message Metadata {
float start_time = 1;
}

oneof metadata_optional {
Metadata metadata = 7;
}

reserved 1, 2, 3;
// Next ID: 7
// Next ID: 8
}

message DecodeResponse {
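A rough sketch of how a client might populate the new field follows. It assumes the generated Python bindings are importable as jetstream.core.proto.jetstream_pb2; the module path and surrounding client code are assumptions for illustration, not part of this diff.

import time

# Assumed import path for the generated protobuf bindings.
from jetstream.core.proto import jetstream_pb2

request = jetstream_pb2.DecodeRequest()
# Assigning a subfield of `metadata` selects the `metadata_optional` oneof.
request.metadata.start_time = time.perf_counter()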
