
Automatic Model Parallelism Through FX #1933

Conversation

@zhenglongjiepheonix (Contributor) commented Jul 1, 2024

What does this PR do?

This PR adds an automatic parallelization backend for torch dynamo. It takes the dynamo-captured fx graph, runs a few passes to automatically identify the parts that can be parallelized, and transforms the graph into its parallelized version. For simplicity it currently focuses on models in the transformers library that support dynamo tracing, and it may not support custom models because of the tricky parts of parallel pattern matching.

For now it only supports parallelization of the linears in the graph, which in the context of transformers means the attention and MLP layers. The following milestones remain:

  • support tensor parallelism on loss layers (requires models to support training-mode tracing)
  • support loading weights from disk/Hub
  • support fusion of parallel linears (QKV fusion, MLP fusion)
  • try supporting sequence parallelism

Please feel free to review and provide suggestions even though it is still in progress and does not yet cover all the features.
According to @michaelbenayoun, we should try to merge this first version; iterations will come in follow-up PRs.
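For context, a dynamo backend here is simply a callable that receives the dynamo-captured fx GraphModule (plus example inputs) and returns the callable to run in its place; the passes described above rewrite that graph before handing it back. Below is a minimal, illustrative sketch of the mechanism only (the backend name and body are made up; the actual parallelize_backend in this PR also receives the parallel execution context and config, bound via functools.partial):

import torch
import torch.nn as nn

def my_parallel_backend(graph_module: torch.fx.GraphModule, example_inputs):
    # ...the analysis/transform passes would rewrite graph_module here...
    graph_module.recompile()  # regenerate forward() after any graph edits
    return graph_module       # dynamo calls the returned callable at runtime

model = nn.Linear(8, 8)
compiled = torch.compile(model, fullgraph=True, backend=my_parallel_backend)
out = compiled(torch.randn(2, 8))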

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fxmarty (Contributor) left a comment:

this is super cool!

rank = dist.get_rank(group=group)

tensor = tensor.contiguous()
tensors = [torch.empty_like(tensor) for _ in range(world_size)]

Contributor Author:

yes

size = tensor.size()
assert size[split_dim] % world_size == 0
tensors = torch.split(tensor, size[split_dim] // world_size, dim=split_dim)
tensor = tensors[rank].contiguous()
Contributor:

why contiguous?

Contributor Author:

tensors after split may not be contiguous, so I think it's better to make them contiguous
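For illustration, here is a tiny standalone example (not from the PR) of why the extra call matters: torch.split returns views into the original tensor, and splitting along a non-leading dimension produces strided, non-contiguous chunks.

import torch

t = torch.arange(12).reshape(3, 4)
chunks = torch.split(t, 2, dim=1)               # views into t, split along dim 1
print(chunks[0].is_contiguous())                # False: a strided view into t
print(chunks[0].contiguous().is_contiguous())   # True: .contiguous() makes a compact copy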

        self.bias.zero_()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = differentiable_identity(input, self.process_group)
Contributor:

why do we need an identity here?

Contributor Author:

to take care of the gradient all-reduce in the backward pass
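For readers unfamiliar with the pattern, here is a minimal sketch of what such a differentiable identity usually looks like (the standard copy-forward / all-reduce-backward trick used in tensor parallelism; an illustration under that assumption, not necessarily the exact implementation in this PR):

import torch
import torch.distributed as dist

class DifferentiableIdentity(torch.autograd.Function):
    # Forward: pass the tensor through unchanged.
    # Backward: all-reduce the gradient across the tensor-parallel group, because
    # the replicated input feeds a parallel shard on every rank.

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None  # no gradient for the process group argument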

self.clear_marker_per_node(node)


class ParallelLinearAnnotatePass(AnalyzeBase):
Contributor:

I don't get the name parallel here? Isn't it more like successive?

Contributor Author:

it means annotating certain layers so they can be replaced by their parallel counterparts

Comment on lines +67 to +73
    tensors = gather_at_main_process(tensor=logits, group=tp_group, rank=rank, world_size=world_size)

    # check results at main worker process
    if rank == 0:
        assert len(tensors) == world_size
        for i in range(1, world_size):
            torch.testing.assert_close(tensors[i - 1].cpu(), tensors[i].cpu(), rtol=1e-4, atol=1e-4)
Contributor:

it should probably be checked on all ranks

Contributor Author:

checking at the main process should be enough, because it gathers the results from the other ranks at the main process and does the comparison there
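For reference, a sketch of the all-rank variant suggested above could look like the following (reusing logits, tp_group and world_size from the test and assuming torch.distributed is imported as dist; an illustration, not the merged code):

    gathered = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(gathered, logits.contiguous(), group=tp_group)
    # every rank now holds all copies and can run the comparison itself
    for i in range(1, world_size):
        torch.testing.assert_close(gathered[i - 1].cpu(), gathered[i].cpu(), rtol=1e-4, atol=1e-4)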

move_model_to_device(model, device=device)
initialize_parameter_mapping(model, ctx=ctx)

model = torch.compile(model, fullgraph=True, backend=partial(parallelize_backend, ctx=ctx, config=cfg))
Contributor:

can we compose with inductor?

Contributor Author:

Not confident enough to say right now, but at the very least it won't be a single graph.

Member:

I would say that is the hope for the future.

Comment on lines 61 to 62
    model (Union[torch.nn.Module, str]):
        Model to parallelize, could either be a module or a model id in huggingface space.
Member:

Suggested change:
     model (Union[torch.nn.Module, str]):
-        Model to parallelize, could either be a module or a model id in huggingface space.
+        Model to parallelize, could either be a module or a model id on the Hugging Face Hub.

        Model to parallelize, could either be a module or a model id in huggingface space.
    parallel_ctx (ParallelExecutionCtx):
        Parallel execution context containing process groups the current process belongs to.
    model_args (additional postional arguments, optional):
Member:

Suggested change:
-    model_args (additional postional arguments, optional):
+    *model_args (Any):

Should we also add model_kwargs?

        Whether to use local files only, will avoid downloading from remote if set to `True`.
    skip_load_weights (`bool`, defaults to `False`):
        Whether to skip loading weights from disk to model.
    kwargs (additional keyword arguments, optional):
Member:

Suggested change:
-    kwargs (additional keyword arguments, optional):
+    **kwargs (Dict[str, Any]):

Comment on lines 69 to 75
    cache_dir (`Optional[str]`, defaults to `None`):
        Cache directory to store downloaded weights. Defaults to None.
    local_files_only (`bool`, defaults to `False`):
        Whether to use local files only, will avoid downloading from remote if set to `True`.
    skip_load_weights (`bool`, defaults to `False`):
        Whether to skip loading weights from disk to model.
    kwargs (additional keyword arguments, optional):
Member:

We provide a lot of things here.
IMO we should simplify that. Most of these arguments come from the from_pretrained method.
So I would gather them as one keyword argument: model_kwargs.
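A hypothetical illustration of that suggestion (the function and parameter names below are made up for the sketch, not the merged API): everything that only exists to be forwarded to from_pretrained moves into a single dict.

def parallelize_model(model, parallel_ctx, *model_args, skip_load_weights=False, model_kwargs=None, **kwargs):
    # model_kwargs is forwarded verbatim to from_pretrained, e.g.
    # model_kwargs = {"cache_dir": "/tmp/hf", "local_files_only": True}
    model_kwargs = model_kwargs or {}
    ...  # remaining **kwargs would be treated as ParallelConfig overrides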

Comment on lines 79 to 82
    for k, v in kwargs.items():
        if k in parallel_config.__dict__:
            setattr(parallel_config, k, v)
    kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__}
Member:

You can also iterate on a copy of kwargs and pop elements as follows:

Suggested change:
-    for k, v in kwargs.items():
-        if k in parallel_config.__dict__:
-            setattr(parallel_config, k, v)
-    kwargs = {k: v for k, v in kwargs.items() if k not in parallel_config.__dict__}
+    for k, v in dict(kwargs).items():
+        if k in parallel_config.__dict__:
+            setattr(parallel_config, k, v)
+            kwargs.pop(k)

    else:
        hf_folder = model

    # should be able to load config using only local files
Member:

No, because you only allow the patterns to be safetensors and bin files, and the config is a JSON file.

Contributor Author:

here I move all the download logic, including the config and index files, into download_model_from_hf

Comment on lines 112 to 116
    use_safetensors = False
    for pattern in allow_patterns:
        if len(glob.glob(os.path.join(hf_folder, pattern))) > 0:
            use_safetensors = pattern == "*.safetensors"
            break
Member:

Can be simplified.
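One possible compaction of that loop, as a sketch only (glob, os, allow_patterns and hf_folder come from the surrounding function; not necessarily what was merged):

    matched = next((p for p in allow_patterns if glob.glob(os.path.join(hf_folder, p))), None)
    use_safetensors = matched == "*.safetensors"

The behavior matches the original: the first pattern with at least one match inside hf_folder decides, and use_safetensors is True only if that pattern is *.safetensors.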

Comment on lines 117 to 137
    index_path = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME)
    if os.path.isfile(index_path):
        with open(index_path) as f:
            index_dict = json.load(f)
        parallel_ctx.weight_map = {k: os.path.join(hf_folder, v) for k, v in index_dict["weight_map"].items()}
    weight_files = glob.glob(os.path.join(hf_folder, "*.safetensors" if use_safetensors else "*.bin"))
    if not use_safetensors:
        weight_map = parallel_ctx.weight_map if parallel_ctx.weight_map else {}
        convert_bin_to_safetensors(model, cache_dir, weight_files, weight_map)
        parallel_ctx.weight_map = weight_map

    # try directly construct weight_map from weight files, should have safetensors file on disk in any case
    if not parallel_ctx.weight_map:
        from safetensors import safe_open

        weight_map, weight_files = {}, glob.glob(os.path.join(hf_folder, "*.safetensors"))
        for weight_file in weight_files:
            with safe_open(filename=weight_file, framework="pt") as f:
                for key in f.keys():
                    weight_map[key] = weight_file
        parallel_ctx.weight_map = weight_map
Member:

nit: I think overall it can be simplified.

Contributor Author:

I moved the logic into utils so the API looks cleaner, but the logic itself is indeed complex, because we need to handle the case where a local directory is passed. In that case the only thing we can do is peek inside the folder and see whether there are safetensors/bin files; if there are only bin files, we need to convert them into safetensors. If there is an index file we load the weight_map directly from it, otherwise we scan all the weight files in the folder and assemble a weight_map out of them.

@zhenglongjiepheonix changed the title from "[WIP] Automatic Model Parallelism Through FX" to "Automatic Model Parallelism Through FX" on Jul 24, 2024
@michaelbenayoun (Member) left a comment:

Not sure I am getting everything because it is a very long and complex PR but LGTM!
Let's iterate on smaller PRs from now on.
Thanks @zhenglongjiepheonix !

@zhenglongjiepheonix (Contributor Author):

Merging this as an experimental first version; more fix-ups and features coming in follow-up PRs!

@zhenglongjiepheonix merged commit 5eaf91b into huggingface:main on Aug 12, 2024 (46 of 47 checks passed)
    options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/
    env:
      NCCL_DEBUG: INFO
      HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
Collaborator:

@zhenglongjiepheonix @michaelbenayoun is HF_TOKEN used for the tests (I can't see where), or can we remove it?

Contributor Author:

it's not used, you can remove it

Collaborator:

removed in #2061
