Skip to content

Commit

Permalink
Add cloud upload to checkpoint conversion script (#151)
Browse files Browse the repository at this point in the history
* actually upload the file

* fix
  • Loading branch information
dakinggg authored May 17, 2023
1 parent afeb7a6 commit 93e3290
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions scripts/misc/convert_examples_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch
from composer.utils import (get_file, maybe_create_object_store_from_uri,
safe_torch_load)
parse_uri, safe_torch_load)

from llmfoundry.models.mpt.configuration_mpt import (attn_config_defaults,
init_config_defaults)
Expand Down Expand Up @@ -79,11 +79,12 @@ def convert_examples_ckpt(
local_ckpt_path = Path(tmp_dir.name) / 'local-composer-checkpoint.pt'

# create object store if output_path
_, _, local_folder_path = parse_uri(output_path)
object_store = maybe_create_object_store_from_uri(str(output_path))
if object_store is not None:
local_output_path = tempfile.TemporaryDirectory().name
else:
local_output_path = output_path
local_output_path = local_folder_path

# create folder
os.makedirs(local_output_path)
Expand Down Expand Up @@ -180,8 +181,15 @@ def convert_examples_ckpt(
param_idx] = param_name

# Save weights
torch.save(composer_state_dict,
Path(local_output_path) / checkpoint_path.split('/')[-1])
file_path = str(Path(local_output_path) / checkpoint_path.split('/')[-1])
print(f'Writing converted output to {file_path}')
torch.save(composer_state_dict, file_path)

if object_store is not None:
remote_file_path = os.path.join(local_folder_path,
checkpoint_path.split('/')[-1])
print(f'Uploading from {file_path} to {remote_file_path}')
object_store.upload_object(remote_file_path, file_path)


def main(args: Namespace) -> None:
Expand Down

0 comments on commit 93e3290

Please sign in to comment.