Add Stable Diffusion #828
Conversation
/cc @divamgupta @innat
So cool, can't wait to have SD ready and in the package...
As this is the first materialized example where we are handling NLP+vision, I don't know if it is interesting to re-introduce some of our related topics. Quoting @chenmoneygithub:
Minor changes
keras_cv/models/generative/stable_diffusion/stable_diffusion.py
@@ -0,0 +1 @@
from tensorflow_addons.layers import GroupNormalization
We need to remove this and duplicate/port/refactor it into the keras-cv or keras repo (#74)
@ianstenbit is working on this now; it will live in core Keras, and temporarily as a CV-internal API
Excited to see this!
A general comment -- are we getting the weights from somewhere, or will we provide a training script later?
def gelu(x):
    tanh_res = keras.activations.tanh(x * 0.7978845608 * (1 + 0.044715 * (x**2)))
numbers seem pretty magical :-)
This is just a polynomial approximation of gelu. Generally speaking, there are a number of constants hardcoded in the code, because the code is non-configurable beyond the arguments of `StableDiffusion` and `text_to_image`. Its only function is to load the original weights and match the original numerics.
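For context, the "magic numbers" are the standard tanh approximation of GELU: 0.7978845608 ≈ sqrt(2/π), and 0.044715 is the usual cubic coefficient. A minimal, self-contained sketch (plain Python rather than the PR's Keras code) comparing the approximation against the exact erf-based GELU:

```python
import math

def gelu_exact(x):
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh_approx(x):
    # Tanh approximation matching the constants in the PR:
    # 0.7978845608 ~= sqrt(2/pi), 0.044715 is the cubic coefficient.
    return 0.5 * x * (1.0 + math.tanh(0.7978845608 * (x + 0.044715 * x**3)))

for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{v:+.1f}  exact={gelu_exact(v):+.6f}  approx={gelu_tanh_approx(v):+.6f}")
```

The two curves agree to roughly 3-4 decimal places, which is why the approximation is safe for matching the original numerics.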
As of now they’re ported. If I can set up LAION-5B (hoping to in the next few months…) we can retrain; LAION-5B was the dataset this model was originally trained on.
Don't forget to add an entry to `keras_cv/models/__init__.py`!
Are we going to target recent OpenCLIP releases?
@@ -0,0 +1,138 @@
import numpy as np
Are we sure that we want to handle all these numpy objects and TF's experimental numpy-wrapped ops, instead of handling Tensors and TF ops directly?
I am asking this also for coherence:
grep -r "import numpy" * --exclude="*test.py" --exclude-dir="*examples*" --exclude-dir="*test*" --exclude-dir="*benchmarks*"
keras_cv/layers/preprocessing/random_rotation.py:import numpy as np
keras_cv/models/object_detection/__internal__.py:import numpy as np
keras_cv/models/object_detection/retina_net/retina_net.py:import numpy as np
Even in these few cases, numpy was used only for `np.log` and `np.pi`. (I don't know if there was a real case where we could not use `pi` from Python's `math` module and/or `tf.math.log`.)
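A minimal sketch of the drop-in replacements the comment suggests, using only the standard library (the variable name is illustrative, not from the PR):

```python
import math

# np.pi is a plain Python float; math.pi is an exact drop-in replacement.
# For scalar constants computed at graph-construction time, math.log works
# too; inside the graph, tf.math.log would operate on Tensors instead.
half_log_two_pi = 0.5 * math.log(2.0 * math.pi)  # e.g. for a Gaussian log-density term
print(half_log_two_pi)
```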
class SimpleTokenizer:
    def __init__(self, bpe_path: str = default_bpe()):
Here's the bug causing your lazy loading to not work: `default_bpe()` in the signature is evaluated when the function is defined, i.e. at import time. Make the default lazy instead:
bpe_path=None
bpe_path = bpe_path or default_bpe()
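A sketch of the suggested lazy-default pattern; `default_bpe` here is a stand-in for the real helper (which locates/downloads the BPE vocabulary file), and the path it returns is a placeholder:

```python
def default_bpe():
    # Stand-in for the real default_bpe(), which fetches the BPE vocabulary;
    # here it just returns a placeholder path.
    print("default_bpe() called")
    return "/tmp/bpe_simple_vocab_16e6.txt"

class SimpleTokenizer:
    def __init__(self, bpe_path=None):
        # Writing `bpe_path: str = default_bpe()` in the signature would run
        # default_bpe() once at import time, defeating lazy loading.
        # Resolving the default in the body defers it to first use.
        self.bpe_path = bpe_path or default_bpe()

t = SimpleTokenizer("my_vocab.txt")  # default_bpe() is never called here
print(t.bpe_path)
```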
@@ -126,13 +141,3 @@ def get_initial_parameters(self, timesteps, batch_size, seed=None):
        (batch_size, self.img_height // 8, self.img_width // 8, 4), seed=seed
    )
    return noise, alphas, alphas_prev
Feel free to add this to `examples/models/stable_diffusion/`. Right now those examples are basically glorified debugging demos, but we could render them with proper polish someday down the line.
keras_cv/models/generative/stable_diffusion/__internal__/layers/group_normalization.py
I added a docstring, removed the TFA dependency, added a code example, and disabled the golden value test so that it doesn't run on CI (which would be overly expensive -- we don't want to download the weights on CI). I believe we're ready to go, or pretty close.
/gcbrun
There will be a lot of transformer-based models released in the future. Should there be separate modules for transformers?
I have reviewed the full PR, and this LGTM. I have also pulled locally and played with the implementation; great job on the readability and organization.
@fchollet you also need to add
Somehow I'm able to run
/gcbrun
Adding a dependency to our GCB cluster; one moment...
/gcbrun
* Add Stable Diffusion
* Further simplification; add files
* Update imports
* Further minor simplifications
* Style fixes
* Further beautification. The code is now 500 LOC excluding the tokenizer and the constants file.
* Readability improvements
* Improve generation loop
* Simplify generation code
* Fix bpe_path
* Add init imports
* Minor style fixes
* Remove unnecessary dependencies and add file headers
* Add test
* Add example
* Add group normalization layer
* Disable test so it doesn't run on CI
* Fix code style
* Update docs
* Format imports
* Remove unused import
* Add file header
* Add more copyright notices
* Add last copyright notice
* Add regex requirement
* Hopefully last copyright notice
This is a working Stable Diffusion text-to-image model. It reuses a bunch of code from Divam's original implementation. All top-level models are full rewrites as Functional models.
All in all it's about 600 LOC. We can probably go below that with further refactoring. The UNet is too declarative right now; it could be refactored to be a lot more concise.
We need to:
* Refactor the UNet (`DiffusionModel`) to be as elegant as possible
* Decide what to do with `constants.py`