From 59538fc0512aef6d0728bd583f604f0bc9118d16 Mon Sep 17 00:00:00 2001 From: vivianrwu Date: Wed, 14 Aug 2024 15:53:09 -0700 Subject: [PATCH] Manual model warmup to resolve AOT model warmup performance degradation (#126) * Implement manual model warmup to resolve performance degradation * fix insert generate compiled * remove check for JetStreamEngine in orchestrator * pyink pylint fixes * change references from aot to warmup * fix non-empty comparison * use all() to check True in entire lists --- jetstream/core/orchestrator.py | 12 --- jetstream/core/server_lib.py | 8 +- jetstream/engine/engine_api.py | 22 +---- .../engine/{aot_utils.py => warmup_utils.py} | 89 +++++-------------- 4 files changed, 32 insertions(+), 99 deletions(-) rename jetstream/engine/{aot_utils.py => warmup_utils.py} (69%) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index cefabd05..a0c77c85 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -143,7 +143,6 @@ class ActiveRequest: prefill_result: Any = None #################### Information relevant for prefill ######################## prefill_content: Optional[str | list[int]] = None - padded_token_length: Optional[int] = None ################## Information relevant for detokenization ################### # Which generate step this was added at. generate_timestep_added: Optional[int] = None @@ -513,11 +512,6 @@ def _prefill_thread(self, idx: int): padded_tokens, true_length = self._process_prefill_content( request, tokenizer, is_bos, prefill_engine.max_prefill_length ) - if isinstance(prefill_engine, engine_api.JetStreamEngine): - request.padded_token_length = token_utils.take_nearest_length( - prefill_engine.prefill_buckets, true_length - ) - prefill_engine.set_padded_token_length(request.padded_token_length) # Compute new kv cache for the prefill_content. prefill_result, first_token = prefill_engine.prefill( @@ -525,7 +519,6 @@ def _prefill_thread(self, idx: int): padded_tokens=padded_tokens, true_length=true_length, ) - request.prefill_result = prefill_result # put first token to detokenize queue @@ -722,11 +715,6 @@ def _generate_thread(self, idx: int): generate_timestep, ) - if isinstance(generate_engine, engine_api.JetStreamEngine): - generate_engine.set_padded_token_length( - new_request.padded_token_length - ) - decode_state = generate_engine.insert( new_request.prefill_result, decode_state, slot=slot ) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 22180f09..b323286a 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -34,7 +34,7 @@ from jetstream.core import orchestrator from jetstream.core.metrics.prometheus import JetstreamMetricsCollector from jetstream.core.proto import jetstream_pb2_grpc -from jetstream.engine import aot_utils, engine_api +from jetstream.engine import warmup_utils, engine_api from prometheus_client import start_http_server @@ -107,7 +107,7 @@ def create_driver( devices: Device objects, will be used to get engine with proper slicing. jax_padding: The flag to enable JAX padding during tokenization. metrics_collector: The JetStream Promethus metric collector. - enable_model_warmup: The flag to enable model server warmup with AOT. + enable_model_warmup: The flag to enable model server warmup. Returns: An orchestrator driver. @@ -142,7 +142,7 @@ def create_driver( ] try: - _ = aot_utils.layout_params_and_compile_executables( + _ = warmup_utils.layout_params_and_compile_executables( prefill_engines, # pylint: disable=protected-access generate_engines, # pylint: disable=protected-access prefill_params, # pylint: disable=protected-access @@ -191,7 +191,7 @@ def run( metrics_server_config: The config to enable Promethus metric server. enable_jax_profiler: The flag to enable JAX profiler server. jax_profiler_port: The port JAX profiler server (default to 9999). - enable_model_warmup: The flag to enable model server warmup with AOT. + enable_model_warmup: The flag to enable model server warmup. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py index cba42939..5277f6df 100644 --- a/jetstream/engine/engine_api.py +++ b/jetstream/engine/engine_api.py @@ -257,22 +257,13 @@ def colocated_cpus(self) -> Union[list[CpuDevices], None]: class JetStreamEngine(Engine): """A wrapper engine of the Engine class. - JetStreamEngine defines the AOT warmed up model server engine. + JetStreamEngine defines the warmed up model server engine. """ def __init__(self, downstream_engine: Engine): self._downstream_engine = downstream_engine - # Executables - self.prefill_executable = None - self.insert_executable = None - self.generate_executable = None - self.prefill_buckets = None - - # Nearest right token length - self._padded_token_length = None - self.warm = False def prefill( @@ -284,9 +275,7 @@ def prefill( true_length: int, ) -> Tuple[Prefix, ResultTokens]: - prefill_result, first_token = self.prefill_executable[ - self.padded_token_length - ]( + prefill_result, first_token = self._downstream_engine.prefill( params=params, padded_tokens=padded_tokens, true_length=true_length, @@ -300,7 +289,7 @@ def insert( slot: int, ) -> DecodeState: - decode_state = self.insert_executable[self.padded_token_length]( + decode_state = self._downstream_engine.insert( prefix=prefix, decode_state=decode_state, slot=slot, @@ -310,7 +299,7 @@ def insert( def generate( self, params: Params, decode_state: DecodeState ) -> Tuple[DecodeState, ResultTokens]: - decode_state, sampled_tokens = self.generate_executable( # pylint: disable=not-callable + decode_state, sampled_tokens = self._downstream_engine.generate( params=params, decode_state=decode_state ) return decode_state, sampled_tokens @@ -355,6 +344,3 @@ def mesh(self) -> jax.sharding.Mesh: @property def colocated_cpus(self) -> Union[list[CpuDevices], None]: return self._downstream_engine.colocated_cpus - - def set_padded_token_length(self, padded_token_length: int): - self.padded_token_length = padded_token_length diff --git a/jetstream/engine/aot_utils.py b/jetstream/engine/warmup_utils.py similarity index 69% rename from jetstream/engine/aot_utils.py rename to jetstream/engine/warmup_utils.py index 65b61f87..6bf7c26a 100644 --- a/jetstream/engine/aot_utils.py +++ b/jetstream/engine/warmup_utils.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""AOT compilation utils.""" +"""Model server warmup utils.""" -import jax import jax.numpy as jnp import concurrent.futures -from typing import Any, Optional, cast +from typing import Any, Optional import logging from jetstream.engine import engine_api, token_utils @@ -44,34 +43,30 @@ def layout_params_and_compile_executables( any_prefill_engine = None any_prefill_params = None - prefill_executables = [] - inserts_generate_executables = [] + prefills_compiled = [] + inserts_generate_compiled = [] for i, pe in enumerate(prefill_engines): any_prefill_engine = pe any_prefill_params = prefill_params[i] - prefill_executable = initialize_prefill_jit_cache( + prefill_compiled = initialize_prefill_jit_cache( prefill_engine=pe, prefill_params=prefill_params[i], prefill_idx=i, ) - prefill_executables.append(prefill_executable) + prefills_compiled.append(prefill_compiled) for i, ge in enumerate(generate_engines): - insert_executable, generate_executable = ( - initialize_insert_generate_jit_cache( - prefill_engine=any_prefill_engine, - generate_engine=ge, - prefill_params=any_prefill_params, - generate_params=generate_params[i], - generate_idx=i, - ) - ) - inserts_generate_executables.append( - [insert_executable, generate_executable] + insert_generate_compiled = initialize_insert_generate_jit_cache( + prefill_engine=any_prefill_engine, + generate_engine=ge, + prefill_params=any_prefill_params, + generate_params=generate_params[i], + generate_idx=i, ) + inserts_generate_compiled.append([insert_generate_compiled]) - if prefill_executables and inserts_generate_executables: + if all(prefills_compiled) and all(inserts_generate_compiled): return True return False @@ -104,47 +99,32 @@ def initialize_prefill_jit_cache( def compile_prefill(length): padded_tokens, true_length = jnp.ones((length), dtype="int32"), length - lowered = jax.jit( - prefill_engine._downstream_engine.prefill, # pylint: disable=protected-access - out_shardings=prefill_engine.get_prefix_destination_sharding(), - ).lower( + _, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access params=prefill_params, padded_tokens=padded_tokens, true_length=true_length, ) - logging.info( - "---------Prefill engine %d lowered for prefill length %d.---------", - prefill_idx, - length, - ) - compiled = lowered.compile() + logging.info( "---------Prefill engine %d compiled for prefill length %d.---------", prefill_idx, length, ) - return compiled logging.info("---------Prefill compilation %d begun.---------", prefill_idx) with concurrent.futures.ThreadPoolExecutor( max_workers=len(prefill_buckets) ) as executor: - prefill_executable = list(executor.map(compile_prefill, prefill_buckets)) - - prefill_executable = { - k: cast(jax.stages.Compiled, e) - for k, e in zip(prefill_buckets, prefill_executable) - } + _ = executor.map(compile_prefill, prefill_buckets) - prefill_engine.prefill_executable = prefill_executable prefill_engine.warm = True logging.info( "---------Prefill compilation %d complete.---------", prefill_idx ) - return prefill_executable + return prefill_engine.warm def initialize_insert_generate_jit_cache( @@ -184,22 +164,13 @@ def compile_insert(length): true_length=true_length, ) - lowered = jax.jit(generate_engine._downstream_engine.insert).lower( # pylint: disable=protected-access - prefix=prefill, decode_state=decode_state, slot=1 - ) - logging.info( - "---------Generate engine %d lowered for insert length %d.---------", - generate_idx, - length, - ) - compiled = lowered.compile() + generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0) logging.info( "---------Generate engine %d compiled for insert length %d.---------", generate_idx, length, ) - return compiled def compile_generate(): @@ -207,16 +178,11 @@ def compile_generate(): "---------Generate compilation %d begun.---------", generate_idx ) - lowered = jax.jit(generate_engine._downstream_engine.generate).lower( # pylint: disable=protected-access + generate_engine._downstream_engine.generate( # pylint: disable=protected-access params=generate_params, decode_state=decode_state, ) - logging.info( - "---------Generate engine %d lowered.---------", - generate_idx, - ) - compiled = lowered.compile() logging.info( "---------Generate engine %d compiled.---------", generate_idx, @@ -226,30 +192,23 @@ def compile_generate(): "---------Generate compilation %d complete.---------", generate_idx ) - return compiled - logging.info( "---------Insertion generation compilation %d begun.---------", generate_idx, ) - generate_executable = compile_generate() + compile_generate() + logging.info( "---------Generate engine %d compiled generation step.---------", generate_idx, ) - generate_engine.generate_executable = generate_executable with concurrent.futures.ThreadPoolExecutor( max_workers=len(prefill_buckets) ) as executor: - insert_executable = list(executor.map(compile_insert, prefill_buckets)) + _ = executor.map(compile_insert, prefill_buckets) - insert_executable = { - k: cast(jax.stages.Compiled, e) - for k, e in zip(prefill_buckets, insert_executable) - } - generate_engine.insert_executable = insert_executable generate_engine.warm = True logging.info( @@ -257,4 +216,4 @@ def compile_generate(): generate_idx, ) - return insert_executable, generate_executable + return generate_engine.warm