Manual model warmup to resolve AOT model warmup performance degradation (#126)

* Implement manual model warmup to resolve performance degradation (see the sketch after this list)

* fix insert generate compiled

* remove check for JetStreamEngine in orchestrator

* pyink pylint fixes

* change references from aot to warmup

* fix non-empty comparison

* use all() to check True in entire lists
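
In rough outline, the change replaces ahead-of-time lower()/compile() of per-bucket executables with a warmup pass that simply calls the engine once per prefill bucket so JAX's JIT cache is populated. The sketch below only illustrates that pattern: the warm_up_engine helper and the init_decode_state() call are assumptions for illustration, while the prefill/insert/generate signatures follow the diff below; the real logic lives in jetstream/engine/warmup_utils.py.

import jax.numpy as jnp

def warm_up_engine(engine, params, prefill_buckets):
  """Runs each model path once per bucket so jax.jit compiles and caches it."""
  decode_state = engine.init_decode_state()  # assumed Engine API helper
  for length in prefill_buckets:
    padded_tokens = jnp.ones((length,), dtype="int32")
    # Calling prefill with a concretely shaped input triggers compilation for
    # this bucket length, replacing the removed AOT executables.
    prefix, _ = engine.prefill(
        params=params, padded_tokens=padded_tokens, true_length=length
    )
    decode_state = engine.insert(prefix=prefix, decode_state=decode_state, slot=0)
  # The generation step is shape-independent, so a single call suffices.
  engine.generate(params=params, decode_state=decode_state)
  engine.warm = True
  return engine.warm

In the actual change, the per-bucket prefill and insert calls are dispatched through a concurrent.futures.ThreadPoolExecutor, as the warmup_utils.py diff shows.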
vivianrwu authored Aug 14, 2024
1 parent e61532d commit 59538fc
Showing 4 changed files with 32 additions and 99 deletions.
12 changes: 0 additions & 12 deletions jetstream/core/orchestrator.py
@@ -143,7 +143,6 @@ class ActiveRequest:
prefill_result: Any = None
#################### Information relevant for prefill ########################
prefill_content: Optional[str | list[int]] = None
padded_token_length: Optional[int] = None
################## Information relevant for detokenization ###################
# Which generate step this was added at.
generate_timestep_added: Optional[int] = None
@@ -513,19 +512,13 @@ def _prefill_thread(self, idx: int):
padded_tokens, true_length = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length
)
if isinstance(prefill_engine, engine_api.JetStreamEngine):
request.padded_token_length = token_utils.take_nearest_length(
prefill_engine.prefill_buckets, true_length
)
prefill_engine.set_padded_token_length(request.padded_token_length)

# Compute new kv cache for the prefill_content.
prefill_result, first_token = prefill_engine.prefill(
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)

request.prefill_result = prefill_result

# put first token to detokenize queue
@@ -722,11 +715,6 @@ def _generate_thread(self, idx: int):
generate_timestep,
)

if isinstance(generate_engine, engine_api.JetStreamEngine):
generate_engine.set_padded_token_length(
new_request.padded_token_length
)

decode_state = generate_engine.insert(
new_request.prefill_result, decode_state, slot=slot
)
8 changes: 4 additions & 4 deletions jetstream/core/server_lib.py
@@ -34,7 +34,7 @@
from jetstream.core import orchestrator
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine import aot_utils, engine_api
from jetstream.engine import warmup_utils, engine_api

from prometheus_client import start_http_server

@@ -107,7 +107,7 @@ def create_driver(
devices: Device objects, will be used to get engine with proper slicing.
jax_padding: The flag to enable JAX padding during tokenization.
metrics_collector: The JetStream Promethus metric collector.
enable_model_warmup: The flag to enable model server warmup with AOT.
enable_model_warmup: The flag to enable model server warmup.
Returns:
An orchestrator driver.
@@ -142,7 +142,7 @@ def create_driver(
]

try:
_ = aot_utils.layout_params_and_compile_executables(
_ = warmup_utils.layout_params_and_compile_executables(
prefill_engines, # pylint: disable=protected-access
generate_engines, # pylint: disable=protected-access
prefill_params, # pylint: disable=protected-access
@@ -191,7 +191,7 @@ def run(
metrics_server_config: The config to enable Promethus metric server.
enable_jax_profiler: The flag to enable JAX profiler server.
jax_profiler_port: The port JAX profiler server (default to 9999).
enable_model_warmup: The flag to enable model server warmup with AOT.
enable_model_warmup: The flag to enable model server warmup.
Returns:
JetStreamServer that wraps the grpc server and orchestrator driver.
22 changes: 4 additions & 18 deletions jetstream/engine/engine_api.py
@@ -257,22 +257,13 @@ def colocated_cpus(self) -> Union[list[CpuDevices], None]:
class JetStreamEngine(Engine):
"""A wrapper engine of the Engine class.
JetStreamEngine defines the AOT warmed up model server engine.
JetStreamEngine defines the warmed up model server engine.
"""

def __init__(self, downstream_engine: Engine):
self._downstream_engine = downstream_engine

# Executables
self.prefill_executable = None
self.insert_executable = None
self.generate_executable = None

self.prefill_buckets = None

# Nearest right token length
self._padded_token_length = None

self.warm = False

def prefill(
@@ -284,9 +275,7 @@ def prefill(
true_length: int,
) -> Tuple[Prefix, ResultTokens]:

prefill_result, first_token = self.prefill_executable[
self.padded_token_length
](
prefill_result, first_token = self._downstream_engine.prefill(
params=params,
padded_tokens=padded_tokens,
true_length=true_length,
@@ -300,7 +289,7 @@ def insert(
slot: int,
) -> DecodeState:

decode_state = self.insert_executable[self.padded_token_length](
decode_state = self._downstream_engine.insert(
prefix=prefix,
decode_state=decode_state,
slot=slot,
@@ -310,7 +299,7 @@ def generate(
def generate(
self, params: Params, decode_state: DecodeState
) -> Tuple[DecodeState, ResultTokens]:
decode_state, sampled_tokens = self.generate_executable( # pylint: disable=not-callable
decode_state, sampled_tokens = self._downstream_engine.generate(
params=params, decode_state=decode_state
)
return decode_state, sampled_tokens
@@ -355,6 +344,3 @@ def mesh(self) -> jax.sharding.Mesh:
@property
def colocated_cpus(self) -> Union[list[CpuDevices], None]:
return self._downstream_engine.colocated_cpus

def set_padded_token_length(self, padded_token_length: int):
self.padded_token_length = padded_token_length
89 changes: 24 additions & 65 deletions jetstream/engine/aot_utils.py → jetstream/engine/warmup_utils.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""AOT compilation utils."""
"""Model server warmup utils."""

import jax
import jax.numpy as jnp
import concurrent.futures
from typing import Any, Optional, cast
from typing import Any, Optional
import logging
from jetstream.engine import engine_api, token_utils

@@ -44,34 +43,30 @@ def layout_params_and_compile_executables(
any_prefill_engine = None
any_prefill_params = None

prefill_executables = []
inserts_generate_executables = []
prefills_compiled = []
inserts_generate_compiled = []

for i, pe in enumerate(prefill_engines):
any_prefill_engine = pe
any_prefill_params = prefill_params[i]
prefill_executable = initialize_prefill_jit_cache(
prefill_compiled = initialize_prefill_jit_cache(
prefill_engine=pe,
prefill_params=prefill_params[i],
prefill_idx=i,
)
prefill_executables.append(prefill_executable)
prefills_compiled.append(prefill_compiled)

for i, ge in enumerate(generate_engines):
insert_executable, generate_executable = (
initialize_insert_generate_jit_cache(
prefill_engine=any_prefill_engine,
generate_engine=ge,
prefill_params=any_prefill_params,
generate_params=generate_params[i],
generate_idx=i,
)
)
inserts_generate_executables.append(
[insert_executable, generate_executable]
insert_generate_compiled = initialize_insert_generate_jit_cache(
prefill_engine=any_prefill_engine,
generate_engine=ge,
prefill_params=any_prefill_params,
generate_params=generate_params[i],
generate_idx=i,
)
inserts_generate_compiled.append([insert_generate_compiled])

if prefill_executables and inserts_generate_executables:
if all(prefills_compiled) and all(inserts_generate_compiled):
return True
return False

@@ -104,47 +99,32 @@ def initialize_prefill_jit_cache(
def compile_prefill(length):
padded_tokens, true_length = jnp.ones((length), dtype="int32"), length

lowered = jax.jit(
prefill_engine._downstream_engine.prefill, # pylint: disable=protected-access
out_shardings=prefill_engine.get_prefix_destination_sharding(),
).lower(
_, _ = prefill_engine._downstream_engine.prefill( # pylint: disable=protected-access
params=prefill_params,
padded_tokens=padded_tokens,
true_length=true_length,
)
logging.info(
"---------Prefill engine %d lowered for prefill length %d.---------",
prefill_idx,
length,
)
compiled = lowered.compile()

logging.info(
"---------Prefill engine %d compiled for prefill length %d.---------",
prefill_idx,
length,
)
return compiled

logging.info("---------Prefill compilation %d begun.---------", prefill_idx)

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(prefill_buckets)
) as executor:
prefill_executable = list(executor.map(compile_prefill, prefill_buckets))

prefill_executable = {
k: cast(jax.stages.Compiled, e)
for k, e in zip(prefill_buckets, prefill_executable)
}
_ = executor.map(compile_prefill, prefill_buckets)

prefill_engine.prefill_executable = prefill_executable
prefill_engine.warm = True

logging.info(
"---------Prefill compilation %d complete.---------", prefill_idx
)

return prefill_executable
return prefill_engine.warm


def initialize_insert_generate_jit_cache(
@@ -184,39 +164,25 @@ def compile_insert(length):
true_length=true_length,
)

lowered = jax.jit(generate_engine._downstream_engine.insert).lower( # pylint: disable=protected-access
prefix=prefill, decode_state=decode_state, slot=1
)
logging.info(
"---------Generate engine %d lowered for insert length %d.---------",
generate_idx,
length,
)
compiled = lowered.compile()
generate_engine.insert(prefix=prefill, decode_state=decode_state, slot=0)

logging.info(
"---------Generate engine %d compiled for insert length %d.---------",
generate_idx,
length,
)
return compiled

def compile_generate():

logging.info(
"---------Generate compilation %d begun.---------", generate_idx
)

lowered = jax.jit(generate_engine._downstream_engine.generate).lower( # pylint: disable=protected-access
generate_engine._downstream_engine.generate( # pylint: disable=protected-access
params=generate_params,
decode_state=decode_state,
)
logging.info(
"---------Generate engine %d lowered.---------",
generate_idx,
)

compiled = lowered.compile()
logging.info(
"---------Generate engine %d compiled.---------",
generate_idx,
@@ -226,35 +192,28 @@ def compile_generate():
"---------Generate compilation %d complete.---------", generate_idx
)

return compiled

logging.info(
"---------Insertion generation compilation %d begun.---------",
generate_idx,
)

generate_executable = compile_generate()
compile_generate()

logging.info(
"---------Generate engine %d compiled generation step.---------",
generate_idx,
)
generate_engine.generate_executable = generate_executable

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(prefill_buckets)
) as executor:
insert_executable = list(executor.map(compile_insert, prefill_buckets))
_ = executor.map(compile_insert, prefill_buckets)

insert_executable = {
k: cast(jax.stages.Compiled, e)
for k, e in zip(prefill_buckets, insert_executable)
}
generate_engine.insert_executable = insert_executable
generate_engine.warm = True

logging.info(
"---------Insertion generation compilation %d complete.---------",
generate_idx,
)

return insert_executable, generate_executable
return generate_engine.warm
