Skip to content

Commit

Permalink
Add kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Jul 16, 2024
1 parent 0015c98 commit 1ba025f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
24 changes: 20 additions & 4 deletions llmfoundry/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,28 @@ def train(cfg: DictConfig) -> Trainer:

if fsdp_config is not None:
if 'load_planner' in fsdp_config:
load_planner_name = fsdp_config['load_planner']
fsdp_config['load_planner'] = build_load_planner(load_planner_name)
load_planners = fsdp_config['load_planner'].items()
if len(load_planners) > 1:
raise ValueError(
'Only one load planner can be specified in the config.',
)
load_planner_name, load_planner_config = load_planners[0]
fsdp_config['load_planner'] = build_load_planner(
load_planner_name,
**load_planner_config,
)

if 'save_planner' in fsdp_config:
save_planner_name = fsdp_config['save_planner']
fsdp_config['save_planner'] = build_save_planner(save_planner_name)
save_planners = fsdp_config['save_planner'].items()
if len(save_planners) > 1:
raise ValueError(
'Only one save planner can be specified in the config.',
)
save_planner_name, save_planner_config = save_planners[0]
fsdp_config['save_planner'] = build_save_planner(
save_planner_name,
**save_planner_config,
)

eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders
icl_tasks_config = train_cfg.icl_tasks or train_cfg.icl_tasks_str
Expand Down
8 changes: 4 additions & 4 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def build_icl_data_and_gauntlet(
return icl_evaluators, logger_keys, eval_gauntlet_cb


def build_load_planner(name: str) -> LoadPlanner:
def build_load_planner(name: str, **kwargs: Any) -> LoadPlanner:
"""Builds a load planner from the registry.
Args:
Expand All @@ -203,11 +203,11 @@ def build_load_planner(name: str) -> LoadPlanner:
partial_function=True,
pre_validation_function=LoadPlanner,
post_validation_function=None,
kwargs={},
kwargs=kwargs,
)


def build_save_planner(name: str) -> SavePlanner:
def build_save_planner(name: str, **kwargs: Any) -> SavePlanner:
"""Builds a save planner from the registry.
Args:
Expand All @@ -222,7 +222,7 @@ def build_save_planner(name: str) -> SavePlanner:
partial_function=True,
pre_validation_function=SavePlanner,
post_validation_function=None,
kwargs={},
kwargs=kwargs,
)


Expand Down

0 comments on commit 1ba025f

Please sign in to comment.