Integrating custom metrics and inference parameters (#82)
Signed-off-by: Gerald Shen <geshen@nvidia.com>
Signed-off-by: Igor Gitman <igitman@nvidia.com>
Co-authored-by: Gerald Shen <geshen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com>
Co-authored-by: Gerald Shen <119401249+gshennvm@users.noreply.github.com>
5 people committed Jan 26, 2024
1 parent dc97194 commit d06b23f
Showing 11 changed files with 247 additions and 60 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### New features and optimizations
- Added public-facing official Dockerfile for NeMo-Aligner
- Memory optimization in PPO that helps avoid OOM in the actor when sending training data to the critic
- SFT: added support for custom validation metrics based on model generations

### Breaking changes

@@ -16,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
a dictionary from the training configuration.
- `exp_manager.max_time_per_run` is now respected, the trainers will save and run validation before exiting if we've reached the time limit.
- Fixed crash in PPO when using a separate reward model server (i.e., with `combine_rm_and_critic_server=False`).
- Fixed a crash when LR scheduler was not specified

## [0.1.0] - 2023-12-04
### Added
5 changes: 4 additions & 1 deletion Dockerfile
@@ -53,7 +53,10 @@ RUN git clone https://github.com/NVIDIA/NeMo.git && \
git checkout FETCH_HEAD; \
fi && \
pip uninstall -y nemo_toolkit sacrebleu && \
git cherry-pick --no-commit -X theirs fa8d416793d850f4ce56bea65e1fe28cc0d092c0 a7f0bc1903493888c31436efc2452ff721fa5a67 && \
git cherry-pick --no-commit -X theirs \
fa8d416793d850f4ce56bea65e1fe28cc0d092c0 \
a7f0bc1903493888c31436efc2452ff721fa5a67 \
7d3d9ac3b1aecf5786b5978a0c1e574701473c62 && \
sed -i 's/shutil.rmtree(ckpt_to_dir(filepath))/shutil.rmtree(ckpt_to_dir(filepath), ignore_errors=True)/g' nemo/collections/nlp/parts/nlp_overrides.py && \
rm -rf .git && pip install -e ".[nlp]" && \
cd nemo/collections/nlp/data/language_modeling/megatron && make
74 changes: 41 additions & 33 deletions examples/nlp/gpt/conf/gpt_sft.yaml
@@ -16,6 +16,14 @@ trainer:
limit_val_batches: 1.0
gradient_clip_val: 1.0

# can be used to register any custom metrics that require token-by-token generation
# inference_metrics:
# my_metric_name1:
# _target_: <metric class>
# my_metric_name2:
# _target_: <metric class>
# <any required arguments>

# do not change these
logger: False # logger provided by exp_manager
enable_checkpointing: False
@@ -39,7 +47,7 @@ exp_manager:
monitor: val_loss
save_top_k: 5
mode: min
save_nemo_on_train_end: False
filename: 'megatron_gpt_sft--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}'
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: False # need to keep this false otherwise it will create multiple last.ckpt files because restore resets the previous best model
@@ -50,7 +58,7 @@ model:
pipeline_model_parallel_size: 1 # inter-layer model parallelism
restore_from_path: ??? # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training.
sync_batch_comm: False
megatron_amp_O2: False
encoder_seq_length: 4096 # the sequence length of the encoder model, it will be overwritten by the loaded GPT model
@@ -60,8 +68,8 @@ model:
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
sequence_parallel: False

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity
@@ -78,6 +86,28 @@ model:
attention_dropout: 0.0
ffn_dropout: 0.0

# can be used to customize behavior of model.generate for inference metrics
# note that if you change any parameter, you must specify all parameters explicitly,
# even those that match the defaults
#
# inference:
# sampling_params:
# use_greedy: False
# temperature: 0.7
# top_k: 0
# top_p: 0.95
# repetition_penalty: 1.0
# add_BOS: True
# all_probs: False
# compute_logprob: False
# end_strings: ["<|endoftext|>", "<extra_id_1>"]
# length_params:
# min_length: 0
# max_length: 512
# strategy:
# _target_: <custom strategy class>
# <any required arguments>

data:
chat: False # whether to use chatbot data or not
chat_prompt_tokens: # special tokens for the chat prompts, a dictionary of {token_type: token}. Note that some tokenizers may combine the characters at the junction between {end_of_turn}{turn_start}, e.g. '<im end><im start>', where the '><' is sometimes merged into a single token. This is not supported; try to avoid it.
@@ -86,12 +116,13 @@ model:
label_start: "\x12"
end_of_turn: "\x0A" # \x0A is '\n'
end_of_name: "\x0A" # \x0A is '\n'

sample: False # create the index mapping files for the sample data, so max_steps * global_batch_size can be larger than the dataset size
num_workers: 1
dataloader_type: single # only supports single
train_ds:
# Example of how to specify paths to multiple datasets
# file_names:
# - /path/to/squad.jsonl
# - /path/to/mnli.jsonl
# - /path/to/boolq.jsonl
@@ -110,7 +141,6 @@ model:
# - 0.5
# - 0.25
# - 0.25
concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'
label_key: 'output'
add_eos: True
add_sep: False
@@ -119,11 +149,10 @@ model:
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
hf_dataset: False # Whether to load the json file with the HuggingFace dataset. Otherwise, the jsonl file will be loaded with the JSONLMemMapDataset.
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']

validation_ds:
file_path: ??? # Path to a JSONL file corresponding to the source data. Data format is identical to validation_ds.
names: null # Names of the corresponding datasets used to log metrics.
global_batch_size: ${model.data.train_ds.global_batch_size}
micro_batch_size: ${model.data.train_ds.micro_batch_size}
shuffle: False
@@ -135,43 +164,22 @@ model:
add_eos: ${model.data.train_ds.add_eos}
add_sep: ${model.data.train_ds.add_sep}
add_bos: ${model.data.train_ds.add_bos}
write_predictions_to_file: False
output_file_path_prefix: null # Prefix of the file to write predictions to.
truncation_field: ${model.data.train_ds.truncation_field} # Options: keys in prompt_template
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
tokens_to_generate: 32 # decides how many tokens we want to generate to evaluate performance with string metrics
hf_dataset: False # Whether to load the json file with the HuggingFace dataset. Otherwise, the jsonl file will be loaded with the JSONLMemMapDataset.
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']

metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'rouge', 'token_f1']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null

test_ds:
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
output_original_text: True # needed for proper metrics support

optim:
name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
lr: 3e-5
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 10
constant_steps: 1000
min_lr: 9e-7

inference:
greedy: True # Whether to use greedy decoding; if False, sampling is used
top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: 1.0 # sampling temperature
all_probs: False # whether to return the log prob for all the tokens in the vocab
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # a flag used to compute the logprob of all the input text; a very special case of running inference, default False
compute_attention_mask: True
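
For readers experimenting with the new (commented-out) `model.inference` section above, the following is a minimal sketch of how such a configuration could be assembled and handed to generation. The keys mirror the commented YAML block in this diff; the `model.generate(...)` call shape shown in the comments is an assumption based on NeMo's text-generation API and is not part of this commit.

```python
# Sketch only: build the optional model.inference section and pass it to generation.
# All keys mirror the commented example in gpt_sft.yaml; the generate() call is assumed.
from omegaconf import OmegaConf

inference_cfg = OmegaConf.create(
    {
        "sampling_params": {
            "use_greedy": False,
            "temperature": 0.7,
            "top_k": 0,
            "top_p": 0.95,
            "repetition_penalty": 1.0,
            "add_BOS": True,
            "all_probs": False,
            "compute_logprob": False,
            "end_strings": ["<|endoftext|>", "<extra_id_1>"],
        },
        "length_params": {"min_length": 0, "max_length": 512},
    }
)

# Assumed call shape (NeMo GPT models expose generate(inputs, length_params, sampling_params)):
# output = model.generate(
#     inputs=list_of_prompt_strings,
#     length_params=OmegaConf.to_container(inference_cfg.length_params),
#     sampling_params=OmegaConf.to_container(inference_cfg.sampling_params),
# )
```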
12 changes: 2 additions & 10 deletions examples/nlp/gpt/train_gpt_sft.py
@@ -12,20 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as np
import torch
import torch.multiprocessing as mp
from megatron.core import parallel_state
from omegaconf.omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import get_prompt_template_example
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingBatchSampler,
)
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
@@ -94,7 +85,6 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
gpt_cfg.data.train_ds.prompt_template = prompt_template
gpt_cfg.data.validation_ds.prompt_template = prompt_template
gpt_cfg.data.test_ds.prompt_template = prompt_template

sft_cls = GPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"
@@ -105,6 +95,8 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

gpt_cfg.inference = cfg.model.get("inference", {})

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
23 changes: 19 additions & 4 deletions nemo_aligner/algorithms/supervised.py
@@ -18,9 +18,9 @@
import torch
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm

from nemo.utils import logging
from nemo_aligner.metrics import InferenceMetricsHandler
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches
@@ -72,6 +72,9 @@ def __init__(
reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX
)

# any metrics that require running full token-by-token inference during validation
self.inference_metrics_handler = InferenceMetricsHandler(cfg.get("inference_metrics"))

def validation_step(self, batch):
self.model.prepare_for_validation_step()

@@ -97,16 +100,22 @@ def run_validation(self):
loss_mean, metrics = self.validation_step(batch)
self.timer.stop("validation_step_time")
validation_step_time = self.timer.get("validation_step_time")

metrics["validation_step_time"] = validation_step_time

if self.inference_metrics_handler.has_metrics():
generation_output = self.run_generation(batch)
self.inference_metrics_handler.update(batch, generation_output)

loss_means.append(loss_mean)
for k, v in metrics.items():
val_metrics[k].append(v)
log_val_metrics = {f"val_{k}": v for k, v in metrics.items()}
val_pbar.set_postfix(log_val_metrics)

val_metrics = {k: mean(v) for k, v in val_metrics.items()}
val_metrics.update(self.inference_metrics_handler.compute())
self.inference_metrics_handler.reset()

return mean(loss_means), val_metrics

def train_single_step(self, batch):
@@ -124,7 +133,8 @@ def train_single_step(self, batch):
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.scheduler.step()
if self.scheduler is not None:
self.scheduler.step()

trainer_metrics = {}
if grad_norm is not None:
@@ -133,6 +143,10 @@

return loss_mean, trainer_metrics | metrics

@torch.no_grad()
def run_generation(self, batch):
return self.model.infer({"text": batch["contexts"], "length": batch["context_lengths"]})

def fit(self):
if self.cfg.max_epochs is not None and self.cfg.max_epochs > 1:
# because we need to figure out a nice way to reset the shuffling on our dataset
@@ -158,6 +172,7 @@ def fit(self):
)

for _, batch in zip(loop_iter, global_pbar):

self.timer.start("train_step_time")
loss, metrics = self.train_single_step(batch)
self.timer.stop("train_step_time")
4 changes: 1 addition & 3 deletions nemo_aligner/data/nlp/builders.py
@@ -280,9 +280,6 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i
index_mapping_dir=data_cfg.get("index_mapping_dir", None),
prompt_template=data_cfg.get("prompt_template", None),
virtual_tokens=0,
tokens_to_generate=data_cfg.get(
"tokens_to_generate", 0
), # used at inference time to allocate tensor positions for tokens that will be generated by inf procedure.
memmap_workers=data_cfg.get(
"memmap_workers", None
), # used to set num. of workers to create the memmap index files
@@ -293,6 +290,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i
"truncation_method", "right"
), # used to choose truncation method. Options: ['random', 'left', 'right']
special_tokens=special_tokens,
output_original_text=data_cfg.get("output_original_text", False),
)
return dataset

16 changes: 16 additions & 0 deletions nemo_aligner/metrics/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# importing all available metrics here to make it easier to import from the config
from nemo_aligner.metrics.common import InferenceMetricsHandler
54 changes: 54 additions & 0 deletions nemo_aligner/metrics/common.py
@@ -0,0 +1,54 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# A couple of common metrics as well as a handler class to provide a unified interface.

from typing import Dict, Optional

import hydra
from omegaconf import DictConfig


class InferenceMetricsHandler:
"""A wrapper around metrics objects that will call update/compute/reset on all registered metrics.
If metrics_config is None, then all methods become no-ops and compute will return an empty dict.
"""

def __init__(self, metrics_config: Optional[DictConfig]):
if metrics_config is None:
metrics_config = {}
self.metrics = hydra.utils.instantiate(metrics_config)

def has_metrics(self) -> bool:
"""Returns True if there are metrics to compute."""
return len(self.metrics) > 0

def update(self, batch: Dict, generation_output: Dict):
"""Calling .update on all metrics.
Batch and generation output are coming directly from
validation dataloader and model.generate respectively.
"""
for metric in self.metrics.values():
metric.update(batch, generation_output)

def compute(self) -> Dict[str, float]:
"""Returns a dictionary with finalized metric values."""
return {name: metric.compute() for name, metric in self.metrics.items()}

def reset(self):
"""Will reset state of all metrics to prepare for the next validation run."""
for metric in self.metrics.values():
metric.reset()
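
As a hypothetical illustration (not part of this commit), a metric only needs `update(batch, generation_output)`, `compute()`, and `reset()` to work with this handler. The batch and generation keys used below (`answers`, `sentences`) and the module path in the registration comment are assumptions made for the sketch.

```python
# Hypothetical custom metric compatible with InferenceMetricsHandler (sketch only).
from typing import Dict


class ExactMatchMetric:
    """Fraction of generations that exactly match the reference answers."""

    def __init__(self):
        self.reset()

    def update(self, batch: Dict, generation_output: Dict):
        # "answers" and "sentences" are assumed keys, used here for illustration only.
        for generated, reference in zip(generation_output["sentences"], batch["answers"]):
            self.correct += int(generated.strip() == reference.strip())
            self.total += 1

    def compute(self) -> float:
        return self.correct / max(self.total, 1)

    def reset(self):
        self.correct = 0
        self.total = 0


# Registration would mirror the commented inference_metrics block in gpt_sft.yaml, e.g.:
# inference_metrics:
#   exact_match:
#     _target_: my_project.metrics.ExactMatchMetric  # hypothetical module path
```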