-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix SegFormer Presets #2222
Fix SegFormer Presets #2222
Conversation
Thank you for the PR Neel! |
} | ||
if preset in aliases: | ||
preset = aliases[preset] | ||
backbone = MiTBackbone.from_preset( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we just chain to super here? And add the mit_b preset to backbone_presets
?
To hit this branch
keras-cv/keras_cv/models/task.py
Lines 134 to 141 in f64fba0
metadata = cls.presets[preset] | |
# Check if preset is backbone-only model | |
if preset in cls.backbone_presets: | |
backbone_cls = keras.saving.get_registered_object( | |
metadata["class_name"] | |
) | |
backbone = backbone_cls.from_preset(preset, load_weights) | |
return cls(backbone, **kwargs) |
That would have the advantage of being totally uniform with how other tasks to this, I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it, slight modifications needed to be made to the task logic here in order to pass input_shape to the backbone's from_preset
method. I've managed to make the uniform flow work, thanks!
We may need to revisit this at some point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
* Fix segformer presets * Add extra SegFormer tests * Change flow, add backbone presets to SegFormer presets * Fix small formatting issue
This PR enables proper functionality for SegFormer presets and adds basic tests.
The following behavior is created here:
SegFormer.from_preset("segformer_b0") -> missing required argument num_clases
SegFormer.from_preset("segformer_b0", num_classes=2) -> ok
SegFormer.from_preset("mit_b0", num_classes=2) -> the exact same