diff --git a/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md index 96c9db81..24bebb56 100644 --- a/docs/online-inference-with-maxtext-engine.md +++ b/docs/online-inference-with-maxtext-engine.md @@ -45,16 +45,16 @@ You can run the JetStream MaxText Server with Gemma and Llama2 models. This sect ### Use a Gemma model checkpoint * You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b). -* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. +* After downloading orbax Gemma checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively. * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}` * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`. * Then, using the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint. ```bash -# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} +# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For gemma-7b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} ``` Note: For more information about the Gemma model and checkpoints, see [About Gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma/Run_Gemma.md). @@ -63,25 +63,25 @@ Note: For more information about the Gemma model and checkpoints, see [About Gem ### Use a Llama2 model checkpoint * You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/). -* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. +* After downloading PyTorch checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively. * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}` * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`. * Then, using the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint. ```bash -# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} +# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For llama2-7b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} # For llama2-13b -bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} +bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET} ${MAXTEXT_BUCKET_SCANNED} ${MAXTEXT_BUCKET_UNSCANNED} ``` Note: For more information about the Llama2 model and checkpoints, see [About Llama2](https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md). -## Step4: Run the JetStream MaxText server +## Step 4: Run the JetStream MaxText server ### Create model config environment variables for server flags @@ -104,8 +104,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=gemma-7b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 @@ -122,8 +122,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-7b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 @@ -131,8 +131,6 @@ export PER_DEVICE_BATCH_SIZE=11 #### Create Llama2-13b environment variables for server flags - - * Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passing into the JetStream MaxText server ```bash @@ -142,8 +140,8 @@ export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-13b export ICI_FSDP_PARALLELISM=1 -export ICI_AUTOREGRESSIVE_PARALLELISM=-1 -export ICI_TENSOR_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=1 +export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=4 @@ -187,7 +185,8 @@ python MaxText/maxengine_server.py \ Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml) -## Step 5: Send test request to JetStream MaxText server +## Step 5: Send a test request to JetStream MaxText server +In a new tab in your terminal, run the following command ```bash cd ~ @@ -207,32 +206,95 @@ Response: to be a fan ## Step 6: Run benchmarks with JetStream MaxText server -Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following: +Note: The JetStream MaxText Server commands from Step 4 are not running with any quantization optimizations. To get the best benchmark results, we need to enable quantization for weights and KV cache. To do this, first generate AQT trained or fine-tuned checkpoints. Then, add the quantization flags and restart the server. + +### Generating a quantized checkpoint + +There are several different quantization configurations to choose from: + +#### int8 DRQ quantized checkpoint +```bash +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` +#### Weights-only int8 quantized checkpoint ```bash -# Enable int8 quantization for both weights and KV cache +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8w save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` + +#### Mixed precision weight-only quantized checkpoint +First, update the mixed precision config file (`MaxText/configs/quantization/mp_scale.json`) in MaxText repo to the mixed-precision-config defined below. +``` +{ + ".*/query": {"bits": 4, "scale": 0.8}, + ".*/key": {"bits": 4, "scale": 0.9}, + ".*/value": {"bits": 8}, + ".*/out": {"bits": 4}, + ".*/wi_0": {"bits": 4}, + ".*/wo": {"bits": 8} +} +``` +Then run the following command: +```bash +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=intmp +quant_cfg_path=configs/quantization/mp_scale.json save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +``` + +### Restart the server with quantization flags + +#### Set flags + +Setting base quantization flags +```bash +# To load an int8 DRQcheckpoint export QUANTIZATION=int8 -export QUANTIZE_KVCACHE=true +export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH} +export CHECKPOINT_IS_QUANTIZED=True + +# To load an int8 weight-only checkpoint +export QUANTIZATION=int8w +export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH} +export CHECKPOINT_IS_QUANTIZED=True + +# To load a Mixed-Precision quantized checkpoint +# If using Mixed-Precision mode, make sure to update the mixed precision config file to the same file as used for quantizing the checkpoint (MaxText/configs/quantization/mp_scale.json) +export QUANTIZATION=intmp +export LOAD_PARAMETERS_PATH${SAVE_QUANT_PARAMS_PATH} +export CHECKPOINT_IS_QUANTIZED=True +export QUANT_CFG_PATH=configs/quantization/mp_scale.json +``` + +The KV-cache is quantized to int8 by using the following config params +```bash +export QUANTIZE_KVCACHE=True +``` +If you don't want to quantize the KV-cache, set +```bash +export QUANTIZE_KVCACHE=False +``` + +#### Restart server +```bash # For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. export PER_DEVICE_BATCH_SIZE=12 cd ~/maxtext python MaxText/maxengine_server.py \ -MaxText/configs/base.yml \ -tokenizer_path=${TOKENIZER_PATH} \ -load_parameters_path=${LOAD_PARAMETERS_PATH} \ -max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ -max_target_length=${MAX_TARGET_LENGTH} \ -model_name=${MODEL_NAME} \ -ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ -ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ -ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ -scan_layers=${SCAN_LAYERS} \ -weight_dtype=${WEIGHT_DTYPE} \ -per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ -quantization=${QUANTIZATION} \ -quantize_kvcache=${QUANTIZE_KVCACHE} + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + quantization=${QUANTIZATION} \ + quantize_kvcache=${QUANTIZE_KVCACHE} ``` ### Benchmarking Gemma-7b @@ -262,10 +324,10 @@ python JetStream/benchmarks/benchmark_serving.py \ --warmup-mode sampled ``` -### Benchmarking Llama2-\*b +### Benchmarking Llama2 ```bash -# Same as Gemma-7b except for the tokenizer (must use a tokenizer that matches your model, which should now be tokenizer.llama2). +# The command is the same as that for the Gemma-7b, except for the tokenizer. Since we need to use a tokenizer that matches the model, it should now be tokenizer.llama2. python JetStream/benchmarks/benchmark_serving.py \ --tokenizer maxtext/assets/tokenizer.llama2 \ @@ -283,10 +345,11 @@ python JetStream/benchmarks/benchmark_serving.py \ # Clean up gcs buckets. gcloud storage buckets delete ${MODEL_BUCKET} gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY} -gcloud storage buckets delete ${DATASET_PATH} + # Clean up repositories. rm -rf maxtext rm -rf JetStream + # Clean up python virtual environment rm -rf .env ``` diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh index 8e2b4d83..0340dbfe 100644 --- a/jetstream/tools/maxtext/model_ckpt_conversion.sh +++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh @@ -28,25 +28,21 @@ export MODEL=$1 export MODEL_VARIATION=$2 export MODEL_NAME=${MODEL}-${MODEL_VARIATION} -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ +# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET # Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). # Point these variables to a GCS bucket that you created. # An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION} export CHKPT_BUCKET=$3 -export MODEL_BUCKET=gs://${USER}-maxtext +export MODEL_BUCKET=$4 -# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run. -export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs - -# Point `DATASET_PATH` to the GCS bucket where you have your training data. -export DATASET_PATH=gs://${USER}-maxtext-dataset +# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint. +export BASE_OUTPUT_DIRECTORY=$5 export BUCKET_LOCATION=US # Create three GCS buckets for the demo. gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true -gcloud storage buckets create ${DATASET_PATH} --location=${BUCKET_LOCATION} || true # Convert model checkpoints to MaxText compatible checkpoints. if [ "$MODEL" == "gemma" ]; then