create trainer class(es) from main_ds.py #20

Closed · wants to merge 7 commits
Conversation

@JamesKunstle (Contributor) commented Jun 14, 2024

Reorganizes the training code currently in main_ds.py into objects.

Adds:

  • DeepSpeedTrainer
  • DeepSpeedModelWrapper
  • MultipackDataWrapper

Subjects for review:

  1. Inputs to the constructors are intentionally exhaustive; I want us to be smart about how we design this. How do we want to wrap these up into configuration objects? Do we want to reuse the config objects that Oleg's PR put together, or do we want our own? (See the sketch after this list.)
  2. How do we expect these objects to be consumed? These are pretty tightly coupled to one another, but one could use a different *DataWrapper object for a different sampling strategy, so they're not totally coupled.
  3. How else could we break these objects up?
  4. Exhaustively define the parameters so contributors / users know which to attend to vs. which to ignore.
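For illustration only, one way the exhaustive constructor inputs could be folded into small config dataclasses; the groupings and field choices here are assumptions, not taken from Oleg's PR or this one:

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Fields chosen for illustration, mirroring arguments that appear in this PR's diff.
    model_name_or_path: str
    output_dir: str
    is_padding_free: bool = False


@dataclass
class DataConfig:
    data_path: str
    effective_batch_size: int
    max_batch_len: int
    num_workers: int = 8
```

The trainer constructor could then accept a couple of these objects instead of one long, flat argument list.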

removes breaking `_arg.copy()` calls (Signed-off-by: James Kunstle <jkunstle@redhat.com>)
@JamesKunstle marked this pull request as ready for review on June 19, 2024 at 19:31
@JamesKunstle changed the title from "TEMP create trainer class(es) from main_ds.py" to "create trainer class(es) from main_ds.py" on Jun 19, 2024
@JamesKunstle (Contributor, Author) commented:

From mistral-finetune:

The goal of this repository is to provide a simple, guided entrypoint to finetune Mistral models. As such, it is fairly opinionated (especially around data formatting) and does not aim at being exhaustive across multiple model architectures or hardware types. For more generic approaches, you can check out some other great projects like [torchtune](https://pytorch.org/torchtune/stable/overview.html).

I think we should adopt the same design philosophy. We should aim to accept a subset of model architectures that we've verified (e.g. granite via dolomite, llama, mistral) with opinionated optimizations for data-sampling throughput; e.g., because of Multipack, we have to do a distributed loss reduce-sum.

Then we can train in the context of the distributed training framework (DeepSpeed vs. FSDP) with a specific trainer class.
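For illustration only (not code from this PR): a minimal sketch of the distributed loss reduce-sum that Multipack makes necessary, since each rank sees a different number of samples/tokens per batch. How this interacts with the framework's own gradient averaging is deliberately left out of the sketch.

```python
import torch
import torch.distributed as dist


def global_mean_loss(local_loss_sum: torch.Tensor, local_num_tokens: int) -> torch.Tensor:
    """Reduce-sum the (detached) loss and token counts across ranks and return the
    global mean, e.g. for consistent scaling/reporting. With Multipack, per-rank
    batch sizes differ, so a naive per-rank mean would weight ranks unevenly."""
    loss_sum = local_loss_sum.detach().clone()
    counts = torch.tensor(float(local_num_tokens), device=loss_sum.device)
    dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    return loss_sum / counts
```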

@@ -112,9 +113,9 @@ def __init__(
     ):
         # NOTE: this is exhaustive. In review, let's decide how to split this up into config groups.
         self.model_name_or_path: str = model_name_or_path
-        self.data_loader: any = data_loader
+        self.data_loader: Any = data_loader
Member:
Does setting these as `Any` make it so that subsequent references are typed? Or will they appear as `any`?

@JamesKunstle (Contributor, Author):
It's because the variables seem naked without a type.

model_name_or_path: str,
data_loader: Any, # TODO: type the DataLoader obj.
output_dir: str,
tokenizer: Any, # TODO: type the tokenizer obj.
Member:
The type is `PreTrainedTokenizer`.

@JamesKunstle (Contributor, Author):
++
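For illustration (assumed class shape, not code from this PR): the constructor arguments could be typed with the concrete classes instead of `Any`.

```python
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer


class DeepSpeedTrainer:
    # Only the typing of the constructor arguments is sketched here.
    def __init__(
        self,
        model_name_or_path: str,
        data_loader: DataLoader,
        output_dir: str,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        self.model_name_or_path = model_name_or_path
        self.data_loader = data_loader
        self.output_dir = output_dir
        self.tokenizer = tokenizer
```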

@fabianlim (Contributor):

Just one comment: maybe the file should be called trainer_ds.py?

distributed_init(_args)

dataw = MultipackDataWrapper(
model_name_or_path=_args.model_name_or_path,
Contributor:
It doesn't seem very intuitive that data prep requires the model name. It seems more intuitive to pass the tokenizer, but this means the model has to be loaded outside.

@JamesKunstle (Contributor, Author):

This was a circular dependency that I'm glad you noticed. We can load the tokenizer independently, which I think breaks the cycle. The model setup step requires information calculated in the DataWrapper class. We could split model and tokenizer loading out from DeepSpeedModelWrapper, which I think would break up the parts better. What do you think?

Contributor:

Yes, I suggest doing any automodel or dolomite loading outside of the dataset and DeepSpeed preparation functions, then just passing it in to those functions as needed.

effective_batch_size=_args.effective_batch_size,
max_batch_len=_args.max_batch_len,
world_size= int(os.getenv("WORLD_SIZE")),
is_padding_free=_args.is_granite,
Contributor:

This introduces some coupling into the DataWrapper, since it actually doesn't really need to know whether the model is granite or not; it just needs some dependency injection.

@JamesKunstle (Contributor, Author):

The DataWrapper does need to know whether the model is padding-free because the MultipackDataSampler uses that information for batch allocation w.r.t. padding. Do you intuit that there's a better split for this class? The sampler could be initialized elsewhere.

@fabianlim (Contributor), Jun 21, 2024:

OK, I misunderstood; the flag is is_padding_free.

But bear in mind that it is typically not canon (at least in the huggingface world) that a dataloader is prepared for a trainer; usually a dataset is passed to a trainer.

Also, it's not canon that a prepared model is passed to the trainer (at least in the huggingface world). Usually an unprepared model (i.e. not yet wrapped with DeepSpeed) is passed to the trainer, and the trainer does the wrapping itself.

Hence what I'm saying is that I feel you really only need one class, DeepSpeedTrainer:

  • MultipackDataWrapper and DeepSpeedModelWrapper can then be converted to internal method calls (see the sketch after this list).
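For illustration only (hypothetical method names and signatures, not from this PR): a minimal sketch of the single-class shape being suggested, where data and model preparation become internal methods.

```python
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer


class DeepSpeedTrainer:
    """Sketch: one trainer that owns its own preparation steps."""

    def __init__(self, model, tokenizer: PreTrainedTokenizer, dataset: Dataset, config) -> None:
        # An unprepared model and a raw dataset come in; the trainer does the wrapping.
        self.config = config
        self.tokenizer = tokenizer
        self.data_loader = self._prepare_dataloader(dataset)
        self.model_engine = self._prepare_model(model)

    def _prepare_dataloader(self, dataset: Dataset) -> DataLoader:
        # Multipack sampling, collation, and padding-free handling would live here.
        ...

    def _prepare_model(self, model):
        # deepspeed.initialize(...) wrapping would live here.
        ...

    def train(self) -> None:
        for batch in self.data_loader:
            ...  # forward / backward / step
```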

@JamesKunstle (Contributor, Author):

++ Okay @fabianlim, that makes a lot of sense.

data_path=_args.data_path,
effective_batch_size=_args.effective_batch_size,
max_batch_len=_args.max_batch_len,
world_size= int(os.getenv("WORLD_SIZE")),
@fabianlim (Contributor), Jun 20, 2024:

In general it's better to get this as torch.distributed.get_world_size(), but this requires initializing torch.distributed first. Right now we rely on deepspeed.initialize to do that.

@JamesKunstle (Contributor, Author):

We initialize distributed a bit earlier; do you suggest swapping it out?

@JamesKunstle (Contributor, Author):

Also, there isn't a torch.distributed.get_local_rank(). Is it typical to just use

local_rank = torch.distributed.get_rank() // torch.distributed.get_world_size()

?

Contributor:

@JamesKunstle No, that is not reliable. Actually torch added a new method get_node_local_rank, but that is only available in bleeding edge https://github.com/pytorch/pytorch/pull/123992/files, so we are out of luck. My suggestion is:

  • follow the official implementation above.
  • wrap this logic in a function or patch it onto torch.distributed (the latter will be better to keep the API future-proof); see the sketch below.

distributed_init(_args)

dataw = MultipackDataWrapper(
Contributor:

I feel that you don't really need an object for this. All that is really needed is a dataloader being prepared.

@JamesKunstle (Contributor, Author):

Very valid feedback. I wrapped these things into a pseudo-dataclass because they're all related downstream, and the calculations required each other in order. It's more a Wrapper than a proper class in that way.

f"samples_per_gpu: {dataw.samples_per_gpu}\033[0m"
)

modelw = DeepSpeedModelWrapper(
Contributor:

Similarly here, all you need is a model being prepared.

@JamesKunstle (Contributor, Author):

As before, these data prep and model prep steps could be functional rather than class-based. That'd be more similar to the mistral-finetune repo that Aldo pointed me toward.
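For illustration only (hypothetical function names and signatures): what the functional split could look like, with the tokenizer and model loaded outside and passed into the prep functions, and values like num_workers exposed as parameters rather than hardcoded.

```python
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer


def prepare_dataloader(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    max_batch_len: int,
    world_size: int,
    is_padding_free: bool,
    num_workers: int = 8,
) -> DataLoader:
    """Build the Multipack-sampled dataloader; no model or tokenizer loading happens here."""
    ...


def prepare_model(model, grad_accum: int, samples_per_gpu: int):
    """Wrap an already-loaded model with deepspeed.initialize(...)."""
    ...
```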


self.packing_max_batch_len, self.grad_accum = (
find_packing_max_batch_len_and_grad_accum(
num_gpus=torch.distributed.get_world_size(),
Contributor:

We already have world_size above, so it's better to just get it from there.

self.data_loader = setup_dataloader(
self.dataset,
self.tokenizer.pad_token_id,
num_workers=8,
Contributor:

Maybe don't hardcode this; let it be an input to the function/wrapper.

@JamesKunstle (Contributor, Author):

@fabianlim, @aldopareja and I are going to reorganize this effort. Based on both of your comments it seems like there's a more lightweight, canonical, and less-coupled way of doing this. I tried to minimize the rewrite with this PR and slice things up in the way that seemed most reusable, but it seems there's a better way.
