diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py
index 4320327c..dc8a00e9 100644
--- a/jetstream/core/metrics/prometheus.py
+++ b/jetstream/core/metrics/prometheus.py
@@ -17,7 +17,6 @@
 import os
 
 import shortuuid
 from prometheus_client import Counter, Gauge, Histogram
-
 from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS
 
@@ -37,21 +36,46 @@ def __new__(cls):
       documentation="Size of prefill queue",
       labelnames=["id"],
   )
+
   _transfer_backlog = Gauge(
       name="jetstream_transfer_backlog_size",
       documentation="Size of transfer queue",
       labelnames=["id", "idx"],
   )
+
   _generate_backlog = Gauge(
       name="jetstream_generate_backlog_size",
       documentation="Size of generate queue",
       labelnames=["id", "idx"],
   )
+
+  _queue_duration = Histogram(
+      name="jetstream_queue_duration",
+      documentation="The total time each request spends enqueued in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+          0.2,
+          0.5,
+          1.0,
+          2.0,
+          5.0,
+          10.0,
+          20.0,
+          50.0,
+          100.0,
+      ],
+  )
+
   _slots_used_percentage = Gauge(
       name="jetstream_slots_used_percentage",
       documentation="The percentage of decode slots currently being used",
       labelnames=["id", "idx"],
   )
+
   _server_startup_latency = Gauge(
       name="jetstream_server_startup_latency",
       documentation="Total time taken to start the Jetstream server",
@@ -96,6 +120,100 @@ def __new__(cls):
       labelnames=["id"],
   )
 
+  _time_to_first_token = Histogram(
+      name="jetstream_time_to_first_token",
+      documentation="Time to first token per request in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.001,
+          0.005,
+          0.01,
+          0.02,
+          0.04,
+          0.06,
+          0.08,
+          0.1,
+          0.25,
+          0.5,
+          0.75,
+          1.0,
+          2.5,
+          5.0,
+          7.5,
+          10.0,
+      ],
+  )
+
+  _time_per_output_token = Histogram(
+      name="jetstream_time_per_output_token",
+      documentation="Average time per output token per request in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.025,
+          0.05,
+          0.075,
+          0.1,
+          0.15,
+          0.2,
+          0.3,
+          0.4,
+          0.5,
+          0.75,
+          1.0,
+          2.5,
+      ],
+  )
+
+  _time_per_prefill_token = Histogram(
+      name="jetstream_time_per_prefill_token",
+      documentation="Prefill time per token per request in seconds",
+      labelnames=["id"],
+      buckets=[
+          0.00001,
+          0.00002,
+          0.00005,
+          0.0001,
+          0.0002,
+          0.0005,
+          0.001,
+          0.002,
+          0.005,
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+      ],
+  )
+
+  _time_per_request = Histogram(
+      name="jetstream_time_per_request",
+      documentation="End to end request latency in seconds",
+      labelnames=["id"],
+      buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
+  )
+
+  _wait_time_per_request = Histogram(
+      name="jetstream_wait_time_per_request",
+      documentation="Time each request is not being prefilled or decoded",
+      labelnames=["id"],
+      buckets=[
+          0.01,
+          0.02,
+          0.05,
+          0.1,
+          0.2,
+          0.5,
+          1.0,
+          2.0,
+          5.0,
+          10.0,
+          20.0,
+          50.0,
+          100.0,
+      ],
+  )
+
   def get_prefill_backlog_metric(self):
     return self._prefill_backlog.labels(id=self._id)
 
@@ -105,12 +223,30 @@ def get_transfer_backlog_metric(self, idx: int):
     return self._transfer_backlog.labels(id=self._id, idx=idx)
 
   def get_generate_backlog_metric(self, idx: int):
     return self._generate_backlog.labels(id=self._id, idx=idx)
 
+  def get_queue_duration(self):
+    return self._queue_duration.labels(id=self._id)
+
   def get_slots_used_percentage_metric(self, idx: int):
     return self._slots_used_percentage.labels(id=self._id, idx=idx)
 
   def get_server_startup_latency_metric(self):
     return self._server_startup_latency.labels(id=self._id)
 
+  def get_time_to_first_token(self):
+    return self._time_to_first_token.labels(id=self._id)
+
+  def get_time_per_output_token(self):
+    return self._time_per_output_token.labels(id=self._id)
+
+  def get_time_per_prefill_token(self):
+    return self._time_per_prefill_token.labels(id=self._id)
+
+  def get_time_per_request(self):
+    return self._time_per_request.labels(id=self._id)
+
+  def get_wait_time_per_request(self):
+    return self._wait_time_per_request.labels(id=self._id)
+
   def get_request_input_length(self):
     return self._request_input_length.labels(id=self._id)
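Note (not part of the patch): the collectors added above are plain `prometheus_client` histograms, so each metric is exported as cumulative `_bucket` series plus `_sum` and `_count`. A minimal standalone sketch of that behaviour, using a throwaway registry and an illustrative metric name rather than the JetStream singleton:

```python
# Standalone sketch (illustrative names, not JetStream code): observing a value
# on a prometheus_client Histogram updates _count, _sum and the cumulative
# _bucket series, which is what the jetstream_* histograms above expose.
from prometheus_client import CollectorRegistry, Histogram, generate_latest

registry = CollectorRegistry()
ttft = Histogram(
    name="demo_time_to_first_token",
    documentation="Time to first token per request in seconds",
    labelnames=["id"],
    buckets=[0.01, 0.1, 1.0, 10.0],
    registry=registry,
)

ttft.labels(id="server-1").observe(0.42)  # one request took 420 ms
print(generate_latest(registry).decode())  # Prometheus text exposition format
```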
diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index 2c54a6f8..cefabd05 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -109,6 +109,24 @@
 root.addHandler(handler)
 
 
+@dataclasses.dataclass
+class ActiveRequestMetadata:
+  """Inference request metadata."""
+
+  start_time: Optional[float] = None
+
+  prefill_enqueue_time: Optional[float] = None
+  prefill_dequeue_time: Optional[float] = None
+
+  transfer_enqueue_time: Optional[float] = None
+  transfer_dequeue_time: Optional[float] = None
+
+  generate_enqueue_time: Optional[float] = None
+  generate_dequeue_time: Optional[float] = None
+
+  complete_time: Optional[float] = None
+
+
 @dataclasses.dataclass
 class ActiveRequest:
   """Current state of the driver."""
@@ -130,6 +148,8 @@ class ActiveRequest:
   # Which generate step this was added at.
   generate_timestep_added: Optional[int] = None
   is_client_side_tokenization: Optional[bool] = False
+  ################## Information relevant for metrics ###################
+  metadata: ActiveRequestMetadata = ActiveRequestMetadata()
 
   def enqueue_samples(self, generated_samples: list[ReturnSample]):
     """Adds the generated sample(s) to return channel for current step.
@@ -477,10 +497,10 @@ def _prefill_thread(self, idx: int):
       my_transfer_backlog = self._transfer_backlogs[idx]
       # The prefill thread can just sleep until it has work to do.
       request = self._prefill_backlog.get(block=True)
-      request_start_time = time.perf_counter()
 
       if request is None:
         break
+      request.metadata.prefill_dequeue_time = time.perf_counter()
       is_bos = True
       logging.info(
           "Prefilling on prefill engine %d : prefill queue size, %d,"
@@ -511,8 +531,10 @@
       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
       my_detokenize_backlog = self._detokenize_backlogs[idx]
+      request.metadata.transfer_enqueue_time = time.perf_counter()
       my_detokenize_backlog.put(
-          (first_token, request, request_start_time), block=True
+          (first_token, request, request.metadata.prefill_dequeue_time),
+          block=True,
       )
 
       # Once prefill is complete, place it on the generation queue and block if
@@ -526,6 +548,15 @@
 
       if self._metrics_collector:
         self._metrics_collector.get_request_input_length().observe(true_length)
 
+      if self._metrics_collector:
+        self._metrics_collector.get_time_per_prefill_token().observe(
+            (
+                request.metadata.transfer_enqueue_time
+                - request.metadata.prefill_dequeue_time
+            )
+            / true_length
+        )
+
       del prefill_result
       del request
@@ -562,6 +593,7 @@
       new_request = transfer_backlog.get(block=True)
       if new_request is None:
         break
+      new_request.metadata.transfer_dequeue_time = time.perf_counter()
       target_idx = min(
           self._generate_backlogs.items(), key=lambda q: q[1].qsize()
       )[0]
@@ -577,6 +609,7 @@
       # Transfer the info to the relevant generate slice.
       self._transfer_prefill_result(new_request, target_idx)
       # Place the request on the correct generate backlog and block if full.
+      new_request.metadata.generate_enqueue_time = time.perf_counter()
      self._generate_backlogs[target_idx].put(new_request, block=True)
       logging.info(
           "Successfully transferred prefill "
@@ -649,6 +682,24 @@
           block |= not self._transfer_backlogs[idx].empty()
         try:
           new_request = my_generate_backlog.get(block=block, timeout=1.0)
+          if new_request is None:
+            break
+          new_request.metadata.generate_dequeue_time = time.perf_counter()
+          if (
+              self._metrics_collector
+              and new_request.metadata.start_time is not None
+          ):
+            self._metrics_collector.get_queue_duration().observe(
+                # Time in prefill queue
+                new_request.metadata.prefill_dequeue_time
+                - new_request.metadata.prefill_enqueue_time
+                # Time in transfer queue
+                + new_request.metadata.transfer_dequeue_time
+                - new_request.metadata.transfer_enqueue_time
+                # Time in generate queue
+                + new_request.metadata.generate_dequeue_time
+                - new_request.metadata.generate_enqueue_time
+            )
           # Got free slot and new request, use them.
         except queue.Empty:
           # No new requests, we can't insert, so put back slot.
@@ -731,7 +782,7 @@
       start_detokenize_time = time.time()
       # prefill first token
       if isinstance(data[0], engine_api.ResultTokens):
-        request_first_token, request, request_start_time = data
+        request_first_token, request, _ = data
         request_first_token = request_first_token.convert_to_numpy()
 
         results, complete = token_utils.process_result_tokens(
@@ -747,9 +798,14 @@
         request.enqueue_samples(results)
 
         first_token_return_time = time.perf_counter()
+        if self._metrics_collector:
+          self._metrics_collector.get_time_to_first_token().observe(
+              first_token_return_time - request.metadata.prefill_dequeue_time
+          )
         logging.info(
             "TTFT duration: %fms",
-            (first_token_return_time - request_start_time) * 1000,
+            (first_token_return_time - request.metadata.prefill_dequeue_time)
+            * 1000,
         )
       # generate step tokens
       elif isinstance(data[1], engine_api.ResultTokens):
@@ -773,12 +829,41 @@
             # Return some output samples.
             request.enqueue_samples(results)
             if request.complete.all():
+              request.metadata.complete_time = time.perf_counter()
+              request.return_channel.close()
               if self._metrics_collector:
                 self._metrics_collector.get_request_output_length().observe(
                     result_tokens.get_result_at_slot(slot).lengths
                 )
                 self._metrics_collector.get_request_success_count_metric().inc()
-              request.return_channel.close()
+                self._metrics_collector.get_time_per_output_token().observe(
+                    (
+                        request.metadata.complete_time
+                        - request.metadata.transfer_enqueue_time
+                    )
+                    / result_tokens.get_result_at_slot(slot).lengths
+                )
+                self._metrics_collector.get_time_per_request().observe(
+                    request.metadata.complete_time
+                    - request.metadata.transfer_enqueue_time
+                )
+
+                if request.metadata.start_time:
+                  total_time = (
+                      request.metadata.complete_time
+                      - request.metadata.start_time
+                  )
+                  prefill_time = (
+                      request.metadata.transfer_enqueue_time
+                      - request.metadata.prefill_dequeue_time
+                  )
+                  generate_time = (
+                      request.metadata.complete_time
+                      - request.metadata.generate_dequeue_time
+                  )
+                  self._metrics_collector.get_wait_time_per_request().observe(
+                      total_time - prefill_time - generate_time
+                  )
             # Place the slot back on the free queue.
             my_live_requests[slot] = None
             my_slots.put(slot, block=False)  # This should always have space.
@@ -895,6 +980,10 @@ async def Decode(  # pylint: disable=invalid-overridden-method
         prefill_content=prefill_content,
         is_client_side_tokenization=is_client_side_tokenization,
         return_channel=return_channel,
+        metadata=ActiveRequestMetadata(
+            start_time=request.metadata.start_time,
+            prefill_enqueue_time=time.perf_counter(),
+        ),
     )
     # The first stage is being prefilled, all other stages are handled
     # inside the driver (transfer, generate*N, detokenize).
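The bookkeeping above reduces to differences of `time.perf_counter()` readings stored on `ActiveRequestMetadata`: `jetstream_queue_duration` sums the three enqueue-to-dequeue gaps, prefill time runs from prefill dequeue to transfer enqueue, generate time from generate dequeue to completion, and `jetstream_wait_time_per_request` is whatever remains of the end-to-end latency. A standalone sketch with invented timestamps:

```python
# Sketch with made-up timestamps; mirrors the arithmetic used in
# _generate_thread (queue duration) and _detokenize_thread (wait time) above.
from dataclasses import dataclass


@dataclass
class Metadata:  # same fields as ActiveRequestMetadata
  start_time: float = 0.00
  prefill_enqueue_time: float = 0.01
  prefill_dequeue_time: float = 0.03
  transfer_enqueue_time: float = 0.08
  transfer_dequeue_time: float = 0.09
  generate_enqueue_time: float = 0.09
  generate_dequeue_time: float = 0.12
  complete_time: float = 0.50


m = Metadata()
queue_duration = (
    (m.prefill_dequeue_time - m.prefill_enqueue_time)
    + (m.transfer_dequeue_time - m.transfer_enqueue_time)
    + (m.generate_dequeue_time - m.generate_enqueue_time)
)
prefill_time = m.transfer_enqueue_time - m.prefill_dequeue_time
generate_time = m.complete_time - m.generate_dequeue_time
wait_time = (m.complete_time - m.start_time) - prefill_time - generate_time
print(round(queue_duration, 3), round(wait_time, 3))  # 0.06 0.07
```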
diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto
index 60c65605..f06d89d5 100644
--- a/jetstream/core/proto/jetstream.proto
+++ b/jetstream/core/proto/jetstream.proto
@@ -50,8 +50,17 @@
     TextContent text_content = 5;
     TokenContent token_content = 6;
   }
+
+  message Metadata {
+    float start_time = 1;
+  }
+
+  oneof metadata_optional {
+    Metadata metadata = 7;
+  }
+
   reserved 1, 2, 3;
-  // Next ID: 7
+  // Next ID: 8
 }
 
 message DecodeResponse {
diff --git a/jetstream/core/proto/jetstream_pb2.py b/jetstream/core/proto/jetstream_pb2.py
index c4be62d5..0b146032 100644
--- a/jetstream/core/proto/jetstream_pb2.py
+++ b/jetstream/core/proto/jetstream_pb2.py
@@ -26,7 +26,7 @@
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
+    b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xfc\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x12;\n\x08metadata\x18\x07 \x01(\x0b\x32\'.jetstream_proto.DecodeRequest.MetadataH\x01\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x1a\x1e\n\x08Metadata\x12\x12\n\nstart_time\x18\x01 \x01(\x02\x42\t\n\x07\x63ontentB\x13\n\x11metadata_optionalJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
 )
 
 _globals = globals()
@@ -37,23 +37,25 @@
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
   _globals["_DECODEREQUEST"]._serialized_start = 58
-  _globals["_DECODEREQUEST"]._serialized_end = 324
-  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233
-  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260
-  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262
-  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295
-  _globals["_DECODERESPONSE"]._serialized_start = 327
-  _globals["_DECODERESPONSE"]._serialized_end = 658
-  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493
-  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509
-  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512
-  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641
-  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600
-  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641
-  _globals["_HEALTHCHECKREQUEST"]._serialized_start = 660
-  _globals["_HEALTHCHECKREQUEST"]._serialized_end = 680
-  _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682
-  _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720
-  _globals["_ORCHESTRATOR"]._serialized_start = 723
-  _globals["_ORCHESTRATOR"]._serialized_end = 908
+  _globals["_DECODEREQUEST"]._serialized_end = 438
+  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 294
+  _globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 321
+  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 323
+  _globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 356
+  _globals["_DECODEREQUEST_METADATA"]._serialized_start = 358
+  _globals["_DECODEREQUEST_METADATA"]._serialized_end = 388
+  _globals["_DECODERESPONSE"]._serialized_start = 441
+  _globals["_DECODERESPONSE"]._serialized_end = 772
+  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 607
+  _globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 623
+  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 626
+  _globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 755
+  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 714
+  _globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 755
+  _globals["_HEALTHCHECKREQUEST"]._serialized_start = 774
+  _globals["_HEALTHCHECKREQUEST"]._serialized_end = 794
+  _globals["_HEALTHCHECKRESPONSE"]._serialized_start = 796
+  _globals["_HEALTHCHECKRESPONSE"]._serialized_end = 834
+  _globals["_ORCHESTRATOR"]._serialized_start = 837
+  _globals["_ORCHESTRATOR"]._serialized_end = 1022
 # @@protoc_insertion_point(module_scope)
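Not part of the patch: a short sketch of how a caller can populate the new optional `Metadata` submessage; message and field names come from the proto above. Note that `start_time` is later compared against `time.perf_counter()` readings taken inside the server process, so it is only meaningful when set by a co-located entrypoint, which is what the HTTP handler change below does.

```python
# Standalone sketch; start_time only makes sense when this code runs in the
# same process as the orchestrator (same perf_counter clock).
import time

from jetstream.core.proto import jetstream_pb2

request = jetstream_pb2.DecodeRequest(
    max_tokens=32,
    text_content=jetstream_pb2.DecodeRequest.TextContent(text="Hello"),
    metadata=jetstream_pb2.DecodeRequest.Metadata(start_time=time.perf_counter()),
)
assert request.WhichOneof("metadata_optional") == "metadata"
```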
diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
index e7dabfed..aaced235 100644
--- a/jetstream/entrypoints/http/api_server.py
+++ b/jetstream/entrypoints/http/api_server.py
@@ -16,6 +16,7 @@
 
 import json
 import logging
+import time
 from typing import Sequence
 from absl import app as abslapp
 from absl import flags
@@ -63,7 +64,11 @@ def root():
 
 @router.post("/v1/generate")
 async def generate(request: DecodeRequest):
+  start_time = time.perf_counter()
   proto_request = Parse(request.json(), jetstream_pb2.DecodeRequest())
+  metadata = jetstream_pb2.DecodeRequest.Metadata()
+  metadata.start_time = start_time
+  proto_request.metadata.CopyFrom(metadata)
   generator = llm_orchestrator.Decode(proto_request)
   return StreamingResponse(
       content=proto_to_json_generator(generator), media_type="text/event-stream"
diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py
index fb003386..cbb8dc6a 100644
--- a/jetstream/entrypoints/http/protocol.py
+++ b/jetstream/entrypoints/http/protocol.py
@@ -25,10 +25,15 @@ class TokenContent(BaseModel):
   token_ids: list[int]
 
 
+class Metadata(BaseModel):
+  start_time: float
+
+
 class DecodeRequest(BaseModel):
   max_tokens: int
   text_content: TextContent | None = None
   token_content: TokenContent | None = None
+  metadata: Metadata | None = None
 
   # Config to enforce the oneof behavior at runtime.
   class Config:
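Not part of the patch: an illustrative request against the HTTP entrypoint. Host, port and prompt are made up; `metadata` is optional in the pydantic model and is filled in server-side by the handler above, so clients do not need to send it.

```python
# Illustrative client call; the server answers with a streamed event-stream body.
import requests

resp = requests.post(
    "http://localhost:8000/v1/generate",  # port is illustrative
    json={"max_tokens": 32, "text_content": {"text": "Why is the sky blue?"}},
    stream=True,
    timeout=60,
)
for line in resp.iter_lines():
  if line:
    print(line.decode())
```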
diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py
index 9114f2fd..2fdddce9 100644
--- a/jetstream/tests/core/test_server.py
+++ b/jetstream/tests/core/test_server.py
@@ -40,6 +40,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase):
       # Uses weight 2 for prefill, 4 for decode.
       (
           config_lib.CPUTestServer,
+          True,
           ["Ċ", "Ō", "Ɵ", ""],
           [266, 332, 415, None],
           [None, None],
@@ -47,6 +48,15 @@ class ServerTest(unittest.IsolatedAsyncioTestCase):
       # Uses the same prefill / generate weights (2).
       (
           config_lib.InterleavedCPUTestServer,
+          True,
+          ["Ċ", "Ə", "ɖ", ""],
+          [266, 399, 598, None],
+          [None],
+      ),
+      # Disable the metrics server.
+      (
+          config_lib.InterleavedCPUTestServer,
+          False,
           ["Ċ", "Ə", "ɖ", ""],
           [266, 399, 598, None],
           [None],
@@ -56,6 +66,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase):
   async def test_server(
       self,
       config: Type[config_lib.ServerConfig],
+      metrics_enabled: bool,
       expected_text: list[str],
       expected_token_ids: list[int | None],
       devices: list[Any],
@@ -63,6 +74,7 @@ class ServerTest(unittest.IsolatedAsyncioTestCase):
     """Sets up a server and requests token responses."""
     ######################### Server side ######################################
     port = portpicker.pick_unused_port()
+    metrics_port = portpicker.pick_unused_port()
     print("port: " + str(port))
 
     credentials = grpc.local_server_credentials()
@@ -72,11 +84,15 @@ class ServerTest(unittest.IsolatedAsyncioTestCase):
         config=config,
         devices=devices,
         credentials=credentials,
+        metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port)
+        if metrics_enabled is True
+        else None,
     )
     ###################### Requester side ######################################
-    # prometheus not configured, assert no metrics collector on Driver
-    assert server._driver._metrics_collector is None  # pylint: disable=protected-access
+    # if prometheus not configured, assert no metrics collector on Driver
+    if metrics_enabled is not True:
+      assert server._driver._metrics_collector is None  # pylint: disable=protected-access
 
     async with grpc.aio.secure_channel(
         f"localhost:{port}", grpc.local_channel_credentials()
     ) as channel:
@@ -106,31 +122,17 @@ class ServerTest(unittest.IsolatedAsyncioTestCase):
           assert output_text == expected_text[counter]
           assert output_token_id == expected_token_ids[counter]
           counter += 1
+    # assert prometheus server is running and responding
+    if metrics_enabled is True:
+      assert server._driver._metrics_collector is not None  # pylint: disable=protected-access
+      assert (
+          requests.get(
+              f"http://localhost:{metrics_port}", timeout=5
+          ).status_code
+          == requests.status_codes.codes["ok"]
+      )
     server.stop()
 
-  def test_prometheus_server(self):
-    port = portpicker.pick_unused_port()
-    metrics_port = portpicker.pick_unused_port()
-
-    print("port: " + str(port))
-    print("metrics port: " + str(metrics_port))
-    credentials = grpc.local_server_credentials()
-    # Now test server with prometheus config
-    server = server_lib.run(
-        port=port,
-        config=config_lib.InterleavedCPUTestServer,
-        devices=[None],
-        credentials=credentials,
-        metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port),
-    )
-    # assert prometheus server is running and responding
-    assert server._driver._metrics_collector is not None  # pylint: disable=protected-access
-    assert (
-        requests.get(f"http://localhost:{metrics_port}", timeout=5).status_code
-        == requests.status_codes.codes["ok"]
-    )
-    server.stop()
-
   def test_jax_profiler_server(self):
     port = portpicker.pick_unused_port()
     print("port: " + str(port))
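Not part of the patch: when a `MetricsServerConfig` is passed to `server_lib.run` (as in the updated test above), the new series can be spot-checked by scraping the metrics port directly. The port below is illustrative.

```python
# Scrape the prometheus_client endpoint and print the new TTFT histogram series.
import requests

body = requests.get("http://localhost:9100", timeout=5).text
for line in body.splitlines():
  if line.startswith("jetstream_time_to_first_token"):
    print(line)
```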