
Design suggestion: remove forward from list of methods to override #838

Closed

elistevens opened this issue Feb 14, 2020 · 18 comments
Labels
discussion (In a discussion stage) · feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · waiting on author (Waiting on user action, correction, or update)

Comments

@elistevens

🚀 Feature (?)

This is more a philosophical design suggestion than a feature request.

I think that presenting LightningModule as a torch.nn.Module-plus-features encourages early experiment designs that don't refactor nicely as the projects using them grow.
I also think that calling self.forward directly is a torch anti-pattern, and should not be encouraged.
I'd like the official docs to suggest using a self.my_model = MyModel(...) in __init__ and y = self.my_model(x) in training_step etc.
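A minimal sketch of that pattern (the names here are hypothetical, and some.other.package stands in for whatever lightning-free package holds the model):

import pytorch_lightning as pl
import torch.nn.functional as F
from some.other.package import MyModel  # a plain nn.Module with no lightning dependency

class MyTrainingLoop(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.my_model = MyModel()  # the model is an attribute, not the base class

    def training_step(self, batch, batch_idx):
        x, target = batch
        y = self.my_model(x)  # goes through nn.Module.__call__, so hooks fire
        return F.cross_entropy(y, target)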

Motivation

I think that most non-research uses of lightning are going to require that the environment the model is trained in be separable from the model itself. This is most obvious when considering the infrastructure needed to load training data vs. production inference data; you're not going to want to drag along all of the libraries needed to connect to a database, decompress data, etc. in the production environment.

To do so, I'd need to be able to from some.other.package import MyModel and then self.my_model = MyModel(...) in __init__. As long as some.other.package doesn't have extra dependencies, I can ship my production model and weights to production without needing everything else that lightning, etc. depends on.

By suggesting that users have the lightning subclass be the model, the set of packages that need to be present in production goes up quite a bit (speaking from experience, the pip version management becomes painful).

Another thing this makes unclear is what is actually happening when training_step gets called. The suggestion "Normally you'd call self.forward() from your training_step() method." implies that self.training_step is happening inside of a self.__call__, since torch.nn.Module.forward isn't supposed to be called directly (it's __call__ that handles hooks, etc.), but that doesn't actually seem to be the case. Unless I'm missing something, this really feels like misuse of the torch API.

By making it clear that your LightningModule subclass should have an instance of your model as an attribute, not be the model, all of the above gets cleared up quite a bit.

Pitch

I think it's a lot cleaner and clearer to say "Normally you'd call y = self.my_model(x) from your training_step() method." and remove any suggestion of overriding self.forward() from the documentation (and I'd in fact make the default implementation of forward raise a YouAreDoingItWrongException).
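A sketch of what that default could look like (YouAreDoingItWrongException is my tongue-in-cheek name, not an existing class):

class YouAreDoingItWrongException(Exception):
    pass

# hypothetical default forward on LightningModule: fail loudly
# instead of silently doing nothing
def forward(self, *args, **kwargs):
    raise YouAreDoingItWrongException(
        "Hold your model as an attribute (self.my_model = MyModel(...)) and "
        "call it from training_step; don't treat the LightningModule as the model."
    )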

As I said earlier, I think that projects that mix training and model code in the same class are going to have a difficult time refactoring things later on, and I think that the perceived simplicity early on is a mastery trap. Anyone familiar with PyTorch isn't going to have a problem defining a separate model class.

Alternatives

Note that I don't think there's anything preventing me from implementing models the way I think is proper right now, but I'm currently investigating whether we can use lightning for more projects in the organization, and I'd really rather not have to educate users to ignore the docs and use the self.my_model() pattern instead.

At the very least, changing the documentation to say "Normally you'd call y = self(x) from your training_step() method." makes sure that hooks, etc. get called as expected.
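The difference matters because hooks hang off __call__, not forward; a quick toy demonstration (not lightning-specific):

import torch
from torch import nn

net = nn.Linear(2, 2)
net.register_forward_hook(lambda module, inp, out: print("hook fired"))

x = torch.rand(1, 2)
net(x)          # prints "hook fired" -- __call__ runs the registered hooks
net.forward(x)  # silent -- calling forward directly bypasses the hooks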

Additional context

Now, I will fully admit that I haven't dug into lightning a ton yet, so it's possible that I'm missing something that will change my understanding/perception of things. If that's the case, I think it should be articulated more clearly.

Thanks for reading.

@elistevens added the "feature" and "help wanted" labels Feb 14, 2020
@djbyrne
Contributor

djbyrne commented Feb 14, 2020

I recently started using Lightning for a project I have been working on, and I needed to import the model from a separate module like you stated @elistevens.

In my Lightning __init__ I just instantiate my external model and override forward to return my_model(x). This works fine; however, I agree that it might be better to have the model as an attribute as opposed to Lightning being the model.

@ghost

ghost commented Feb 19, 2020

This would also help with more complicated research projects that involve multiple models (autoencoders or GANs, for example) and make things a lot more flexible and pythonic, "pytorchic."

@jeremyjordan
Contributor

@darwinkim I agree with @elistevens that it will be useful to be able to "extract" a more lightweight model to ship to production. However, can you provide an example of how this would help with more complicated research projects that involve multiple models?

There's a GAN example here which shows how you can cleanly incorporate multiple models.

@jeremyjordan
Contributor

I wonder if there's a way we could expose a jit_export functionality which JIT-compiles the model and extracts only what is necessary for serving inference. The LightningModule would contain everything needed for training, and it could export a minimal model for inference as a post-training artifact.
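A rough sketch of what that could look like, assuming the model lives on a self.model attribute and is scriptable (jit_export is a hypothetical helper, not an existing API):

import torch

def jit_export(lightning_module, path):
    # keep only the scripted inner model; the training-loop machinery
    # on the LightningModule stays behind
    torch.jit.script(lightning_module.model).save(path)

# in production, only torch is needed to load and run it:
# model = torch.jit.load("model.pt")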

@ghost

ghost commented Feb 19, 2020

@jeremyjordan
https://arxiv.org/abs/1703.00848
Involves training six networks as two autoencoders and two GANs on two datasets

@Borda added the "discussion" label and removed the "help wanted" label Feb 21, 2020
@awaelchli
Member

just a note: PR #1211 promotes the use of self(...) instead of self.forward in examples and docs.

@Borda
Member

Borda commented Apr 5, 2020

@williamFalcon @PyTorchLightning/core-contributors ^^

@stale

stale bot commented Jun 4, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the "won't fix" label Jun 4, 2020
@Borda added the "help wanted" and "Important" labels and removed the "won't fix" label Jun 10, 2020
@stale

stale bot commented Aug 10, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the "won't fix" label Aug 10, 2020
@elistevens
Author

@williamFalcon asked me to revisit this, so I'm adding some more thoughts. PR #1211 fixed the issue of suggesting that users call .forward() directly, but there's another layer to what I'm trying to suggest.

Essentially, I want to clearly and cleanly separate concerns, and have that clean separation be suggested by the documentation. From an OOP perspective, the documentation suggests that the training loop object and the model object be the same object, which mixes two separate concepts.

Put another way, if you were going to use a stock model from torchvision, you wouldn't have class MyModel(pl.LightningModule, ResNet); you'd have class MyModel(pl.LightningModule) with self.model = ResNet() in __init__. The training loop and the model would be separate Python objects, and you could do things with the model such that you'd never know it was trained with lightning. For example, save/load the weights, or export it to ONNX, etc.
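Spelled out as code (a sketch; the loss and optimizer choices are placeholders):

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torchvision.models import resnet18

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet18()  # a stock torchvision model, owned as an attribute

    def training_step(self, batch, batch_idx):
        x, target = batch
        return F.cross_entropy(self.model(x), target)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.01)

# the inner model never has to know about lightning:
task = MyModel()
torch.save(task.model.state_dict(), "resnet.pt")  # a plain torch checkpoint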

As it's suggested now, it becomes much harder to pull my model out and use it in some other context (like a different training loop). I typically try to avoid libraries with that kind of lock-in.

The stale bot removed the "won't fix" label Aug 10, 2020
@williamFalcon
Contributor

williamFalcon commented Aug 10, 2020

Thanks for adding more details!

I use lightning a lot the way you describe. What gives you the impression that you can’t use it this way?

Is there a better way to show this in the docs or examples?

First, take a look at all the bolts models. Most models in bolts have that pattern.

Second, when you are done you can load the full thing and pull out whatever parts are interesting to you (ie: just the encoder of a GAN), or make the forward use only the encoder.
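For example (a sketch; GANLightningModule and its encoder attribute are hypothetical names):

import torch

# load the full checkpoint, then keep only the part you care about
full = GANLightningModule.load_from_checkpoint("gan.ckpt")
encoder = full.encoder  # a plain nn.Module, usable anywhere
torch.save(encoder.state_dict(), "encoder.pt")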

But yeah, you can always drop a model into a LightningModule and use the LightningModule purely as the training loop for the model.

https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol

Finally, we can make an example or write up more clearly in the docs what you're looking for, if you'd prefer:

class ClassificationTask(LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        ...

    def configure_optimizers(self):
        ...


model = Resnet50()
loops = ClassificationTask(model)
trainer = Trainer(...)
trainer.fit(loops, train_dataloader, val_dataloader)

In fact, we can add a new section to bolts with these prebuilt loops. Classification loops, fine-tuning loop, etc...

@williamFalcon
Contributor

Ok, added the following to the docs to clarify this particular use case of a lightning module.

[screenshot of the new docs section]

@awaelchli
Member

Can we at least raise a NotImplementedError like PyTorch does? I only just now noticed that in the current version, LightningModule actually implements forward for you, returning None. Why is that?
My expectation was that LightningModule behaves like an nn.Module outside the context of PL.

@williamFalcon
Contributor

williamFalcon commented Aug 22, 2020

it does but forward is not required...

we want to separate training from inference. in training you use the *_step methods (training_step, etc.).

if your model also happens to do inference, then it should implement forward.

this makes a clean separation between pure training scripts and models.

this removal also enables tasks which weren’t possible before.

@awaelchli
Member

awaelchli commented Aug 22, 2020

All of this is clear. No problem with that. If you don't use forward, all is good.
I suggest raising NotImplementedError if you call self.forward anywhere, instead of just returning None.

@awaelchli
Member

awaelchli commented Aug 22, 2020

import torch
from torch import nn
from pytorch_lightning import LightningModule


class Lightning(LightningModule):
    pass


class Torch(nn.Module):
    pass


lightning_model = Lightning()
print(lightning_model(torch.rand(2, 2)))  # does not raise, returns None, why?

torch_model = Torch()
print(torch_model(torch.rand(2, 2)))  # raises NotImplementedError, good!

@tchaton
Contributor

tchaton commented Mar 9, 2021

Hey,

Any updates on this issue?

Best,
T.C

@kaushikb11 added the "waiting on author" label Apr 20, 2021
@awaelchli
Member

all items here were addressed a while ago. we can close

  • forward is optional to implement
  • our docs favor self() instead of .forward()
  • if someone calls an unimplemented forward manually, it will call into super().forward
