Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 000 propagate imagenet dataset params #1368

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ImagesClassificationPrediction,
ClassificationPrediction,
)
from torch.nn.functional import softmax
from super_gradients.training.utils.utils import generate_batch
from super_gradients.training.utils.media.video import load_video, includes_video_extension
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
Expand Down Expand Up @@ -410,17 +409,17 @@ def __init__(
def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[ClassificationPrediction]:
"""Decode the model output

:param model_output: Direct output of the model, without any post-processing.
:param model_output: Direct output of the model, without any post-processing. Tensor of shape [B, C]
:param model_input: Model input (i.e. images after preprocessing).
:return: Predicted Bboxes.
"""
confidence_predictions, classifier_predictions = torch.max(model_output, 1)
pred_scores, pred_labels = torch.max(model_output.softmax(dim=1), 1)

classifier_predictions = classifier_predictions.detach().cpu().numpy()
confidence_predictions = softmax(confidence_predictions).detach().cpu().numpy()
pred_labels = pred_labels.detach().cpu().numpy() # [B,1]
pred_scores = pred_scores.detach().cpu().numpy() # [B,1]

predictions = list()
for prediction, confidence, image_input in zip(classifier_predictions, confidence_predictions, model_input):
for prediction, confidence, image_input in zip(pred_labels, pred_scores, model_input):
predictions.append(ClassificationPrediction(confidence=float(confidence), label=int(prediction), image_shape=image_input.shape))
return predictions

Expand Down
23 changes: 19 additions & 4 deletions src/super_gradients/training/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,22 @@ def default_dekr_coco_processing_params() -> dict:
return params


def default_resnet_imagenet_processing_params() -> dict:
def default_imagenet_processing_params() -> dict:
"""Processing parameters commonly used for training resnet on Imagenet dataset."""
image_processor = ComposeProcessing(
[Resize(size=256), CenterCrop(size=224), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), StandardizeImage(), ImagePermute()]
[Resize(size=256), CenterCrop(size=224), StandardizeImage(), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ImagePermute()]
)
params = dict(
class_names=IMAGENET_CLASSES,
image_processor=image_processor,
)
return params


def default_vit_imagenet_processing_params() -> dict:
"""Processing parameters used by ViT for training resnet on Imagenet dataset."""
image_processor = ComposeProcessing(
[Resize(size=256), CenterCrop(size=224), StandardizeImage(), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ImagePermute()]
)
params = dict(
class_names=IMAGENET_CLASSES,
Expand All @@ -668,7 +680,10 @@ def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -
if pretrained_weights == "coco_pose" and model_name in ("dekr_w32_no_dc", "dekr_custom"):
return default_dekr_coco_processing_params()

if pretrained_weights == "imagenet" and model_name == "resnet18":
return default_resnet_imagenet_processing_params()
if pretrained_weights == "imagenet" and model_name in {"vit_base", "vit_large", "vit_huge"}:
return default_vit_imagenet_processing_params()

if pretrained_weights == "imagenet":
return default_imagenet_processing_params()

return dict()
11 changes: 6 additions & 5 deletions tests/unit_tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ def setUp(self) -> None:
]

def test_classification_models(self):
model = models.get(Models.RESNET18, pretrained_weights="imagenet")

with tempfile.TemporaryDirectory() as tmp_dirname:
predictions = model.predict(self.images)
predictions.show()
predictions.save(output_folder=tmp_dirname)
for model_name in {Models.RESNET18, Models.EFFICIENTNET_B0, Models.MOBILENET_V2, Models.REGNETY200}:
model = models.get(model_name, pretrained_weights="imagenet")

predictions = model.predict(self.images)
predictions.show()
predictions.save(output_folder=tmp_dirname)

def test_pose_estimation_models(self):
model = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose")
Expand Down