[ML] Better memory estimation for NLP models (#568)
This PR adds the ability to estimate the per-deployment and per-allocation memory usage of NLP transformer models. It uses torch.profiler to log the peak memory usage during inference.

This information is then used in Elasticsearch to provision models with sufficient memory (elastic/elasticsearch#98874).
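
For context, the measurement can be reproduced outside eland with a few lines. The following is a minimal sketch of the two quantities involved, using a toy model and input sizes chosen purely for illustration (they are not taken from this PR):

```python
# Minimal sketch (not from this PR): measuring static model size and the CPU
# memory reported by torch.profiler for one forward pass. The toy model and
# input sizes below are made-up assumptions for illustration.
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

model = nn.Sequential(nn.Linear(768, 768), nn.GELU(), nn.Linear(768, 2))
example_input = torch.randn(1, 768)  # pretend batch size 1

# Static ("per deployment") size: bytes held by parameters and buffers.
static_bytes = sum(
    p.nelement() * p.element_size() for p in model.parameters()
) + sum(b.nelement() * b.element_size() for b in model.buffers())

# Transient ("per allocation") size: CPU memory the profiler reports for inference.
with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    model(example_input)
transient_bytes = prof.key_averages().total_average().cpu_memory_usage

print(f"static: {static_bytes} B, transient: {transient_bytes} B")
```

The diff below applies these two measurements to the wrapped transformer model and stores the results under `per_deployment_memory_bytes` and `per_allocation_memory_bytes` in the trained model configuration metadata.
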
valeriy42 committed Nov 6, 2023
1 parent 28e6d92 commit 6cecb45
Showing 5 changed files with 158 additions and 5 deletions.
9 changes: 7 additions & 2 deletions .buildkite/run-elasticsearch.sh
@@ -16,7 +16,12 @@ fi

set -euxo pipefail

SCRIPT_PATH=$(dirname $(realpath -s $0))
# realpath on macOS uses different flags than on Linux
if [[ "$OSTYPE" == "darwin"* ]]; then
  SCRIPT_PATH=$(dirname $(realpath $0))
else
  SCRIPT_PATH=$(dirname $(realpath -s $0))
fi

moniker=$(echo "$ELASTICSEARCH_VERSION" | tr -C "[:alnum:]" '-')
suffix=rest-test
@@ -132,7 +137,7 @@ url="http://elastic:$ELASTIC_PASSWORD@$NODE_NAME"
docker_pull_attempts=0
until [ "$docker_pull_attempts" -ge 5 ]
do
  docker pull docker.elastic.co/elasticsearch/"$ELASTICSEARCH_VERSION" && break
  docker pull docker.elastic.co/elasticsearch/$ELASTICSEARCH_VERSION && break
  docker_pull_attempts=$((docker_pull_attempts+1))
  sleep 10
done
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -169,7 +169,7 @@ currently using a minimum version of PyCharm 2019.2.4.
* Setup Elasticsearch instance with docker

``` bash
> ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh
> ELASTICSEARCH_VERSION=elasticsearch:8.x-SNAPSHOT BUILDKITE=false .buildkite/run-elasticsearch.sh
```

* Now check `http://localhost:9200`
4 changes: 4 additions & 0 deletions eland/ml/pytorch/traceable_model.py
@@ -66,3 +66,7 @@ def save(self, path: str) -> str:
        trace_model = torch.jit.freeze(trace_model)
        torch.jit.save(trace_model, model_path)
        return model_path

    @property
    def model(self) -> nn.Module:
        return self._model
139 changes: 137 additions & 2 deletions eland/ml/pytorch/transformers.py
@@ -22,6 +22,7 @@

import json
import os.path
import random
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Union
@@ -30,6 +31,7 @@
import transformers # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from torch import Tensor, nn
from torch.profiler import profile # type: ignore
from transformers import (
AutoConfig,
AutoModel,
@@ -270,8 +272,8 @@ def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        _token_type_ids: Tensor,
        _position_ids: Tensor,
        _token_type_ids: Tensor = None,
        _position_ids: Tensor = None,
    ) -> Tensor:
        """Wrap the input and output to conform to the native process interface."""

@@ -769,13 +771,146 @@ def _create_config(
                tokenization=tokenization_config
            )

        # add static and dynamic memory state size to metadata
        per_deployment_memory_bytes = self._get_per_deployment_memory()

        per_allocation_memory_bytes = self._get_per_allocation_memory(
            tokenization_config.max_sequence_length, 1
        )

        metadata = {
            "per_deployment_memory_bytes": per_deployment_memory_bytes,
            "per_allocation_memory_bytes": per_allocation_memory_bytes,
        }

        return NlpTrainedModelConfig(
            description=f"Model {self._model_id} for task type '{self._task_type}'",
            model_type="pytorch",
            inference_config=inference_config,
            input=TrainedModelInput(
                field_names=["text_field"],
            ),
            metadata=metadata,
        )

    def _get_per_deployment_memory(self) -> float:
        """
        Returns the static memory size of the model in bytes.
        """
        psize: float = sum(
            param.nelement() * param.element_size()
            for param in self._traceable_model.model.parameters()
        )
        bsize: float = sum(
            buffer.nelement() * buffer.element_size()
            for buffer in self._traceable_model.model.buffers()
        )
        return psize + bsize

    def _get_per_allocation_memory(
        self, max_seq_length: Optional[int], batch_size: int
    ) -> float:
        """
        Returns the transient memory size of the model in bytes.
        Parameters
        ----------
        max_seq_length : Optional[int]
            Maximum sequence length to use for the model.
        batch_size : int
            Batch size to use for the model.
        """
        activities = [torch.profiler.ProfilerActivity.CPU]

        # Get the memory usage of the model with a batch size of 1.
        inputs_1 = self._get_model_inputs(max_seq_length, 1)
        with profile(activities=activities, profile_memory=True) as prof:
            self._traceable_model.model(*inputs_1)
        mem1: float = prof.key_averages().total_average().cpu_memory_usage

        # This is measuring memory usage of the model with a batch size of 2 and
        # then linearly extrapolating it to get the memory usage of the model for
        # a batch size of batch_size.
        if batch_size == 1:
            return mem1
        inputs_2 = self._get_model_inputs(max_seq_length, 2)
        with profile(activities=activities, profile_memory=True) as prof:
            self._traceable_model.model(*inputs_2)
        mem2: float = prof.key_averages().total_average().cpu_memory_usage
        return mem1 + (mem2 - mem1) * (batch_size - 1)

    def _get_model_inputs(
        self,
        max_length: Optional[int],
        batch_size: int,
    ) -> Tuple[Tensor, ...]:
        """
        Returns a random batch of inputs for the model.
        Parameters
        ----------
        max_length : Optional[int]
            Maximum sequence length to use for the model. Default is 512.
        batch_size : int
            Batch size to use for the model.
        """
        vocab: list[str] = list(self._tokenizer.get_vocab().keys())

        # if optional max_length is not set, set it to 512
        if max_length is None:
            max_length = 512

        # generate random text
        texts: list[str] = [
            " ".join(random.choices(vocab, k=max_length)) for _ in range(batch_size)
        ]

        # tokenize text
        inputs: transformers.BatchEncoding = self._tokenizer(
            texts,
            padding="max_length",
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )

        return self._make_inputs_compatible(inputs)

    def _make_inputs_compatible(
        self, inputs: transformers.BatchEncoding
    ) -> Tuple[Tensor, ...]:
        """
        Make the input batch format compatible with the model's requirements.
        Parameters
        ----------
        inputs : transformers.BatchEncoding
            The input batch to make compatible.
        """
        # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
        if "token_type_ids" not in inputs:
            inputs["token_type_ids"] = torch.zeros(
                inputs["input_ids"].size(1), dtype=torch.long
            )
        if isinstance(
            self._tokenizer,
            (
                transformers.BartTokenizer,
                transformers.MPNetTokenizer,
                transformers.RobertaTokenizer,
                transformers.XLMRobertaTokenizer,
            ),
        ):
            del inputs["token_type_ids"]
            return (inputs["input_ids"], inputs["attention_mask"])

        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
        inputs["position_ids"] = position_ids
        return (
            inputs["input_ids"],
            inputs["attention_mask"],
            inputs["token_type_ids"],
            inputs["position_ids"],
        )

    def _create_traceable_model(self) -> _TransformerTraceableModel:
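
The per-allocation estimate in `_get_per_allocation_memory` above extrapolates linearly from measurements at batch sizes 1 and 2. A minimal standalone sketch of that idea follows; the toy model and the `measure_cpu_memory_bytes` helper are illustrative assumptions, not eland APIs:

```python
# Sketch of the linear extrapolation used by _get_per_allocation_memory:
# profile CPU memory at batch sizes 1 and 2, then assume memory grows
# linearly with batch size. The toy model and helper name are assumptions.
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile

model = nn.Linear(128, 128)


def measure_cpu_memory_bytes(batch_size: int) -> float:
    """CPU memory (bytes) reported by the profiler for one forward pass."""
    x = torch.randn(batch_size, 128)
    with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
        model(x)
    return prof.key_averages().total_average().cpu_memory_usage


mem1 = measure_cpu_memory_bytes(1)
mem2 = measure_cpu_memory_bytes(2)
batch_size = 8
# Linear model: mem(n) ~= mem1 + (mem2 - mem1) * (n - 1)
estimate = mem1 + (mem2 - mem1) * (batch_size - 1)
print(f"estimated memory at batch size {batch_size}: {estimate} bytes")
```
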
9 changes: 9 additions & 0 deletions tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -173,6 +173,15 @@ def test_text_prediction(
        assert ["text_field"] == config.input.field_names
        assert isinstance(config.inference_config, config_type)
        tokenization = config.inference_config.tokenization
        assert isinstance(config.metadata, dict)
        assert (
            "per_deployment_memory_bytes" in config.metadata
            and config.metadata["per_deployment_memory_bytes"] > 0
        )
        assert (
            "per_allocation_memory_bytes" in config.metadata
            and config.metadata["per_allocation_memory_bytes"] > 0
        )
        assert isinstance(tokenization, tokenizer_type)
        assert max_sequence_len == tokenization.max_sequence_length

