Skip to content

Commit

Permalink
Feature/sg 000 propagate imagenet dataset params (#1368)
Browse files Browse the repository at this point in the history
* Propagate default dataset processing params for other classification models

* Fix bug in predict pipeline (Softmax was computed along batch dimension AFTER taking max along classes dimension)

* Added more classification models to test
  • Loading branch information
BloodAxe committed Aug 11, 2023
1 parent fb01f1f commit b6499b6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
11 changes: 5 additions & 6 deletions src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
ImagesClassificationPrediction,
ClassificationPrediction,
)
from torch.nn.functional import softmax
from super_gradients.training.utils.utils import generate_batch
from super_gradients.training.utils.media.video import load_video, includes_video_extension
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
Expand Down Expand Up @@ -410,17 +409,17 @@ def __init__(
def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[ClassificationPrediction]:
"""Decode the model output
:param model_output: Direct output of the model, without any post-processing.
:param model_output: Direct output of the model, without any post-processing. Tensor of shape [B, C]
:param model_input: Model input (i.e. images after preprocessing).
:return: List of classification predictions (label and confidence), one per image in the batch.
"""
confidence_predictions, classifier_predictions = torch.max(model_output, 1)
pred_scores, pred_labels = torch.max(model_output.softmax(dim=1), 1)

classifier_predictions = classifier_predictions.detach().cpu().numpy()
confidence_predictions = softmax(confidence_predictions).detach().cpu().numpy()
pred_labels = pred_labels.detach().cpu().numpy() # [B]
pred_scores = pred_scores.detach().cpu().numpy() # [B]

predictions = list()
for prediction, confidence, image_input in zip(classifier_predictions, confidence_predictions, model_input):
for prediction, confidence, image_input in zip(pred_labels, pred_scores, model_input):
predictions.append(ClassificationPrediction(confidence=float(confidence), label=int(prediction), image_shape=image_input.shape))
return predictions

Expand Down
23 changes: 19 additions & 4 deletions src/super_gradients/training/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,22 @@ def default_dekr_coco_processing_params() -> dict:
return params


def default_resnet_imagenet_processing_params() -> dict:
def default_imagenet_processing_params() -> dict:
"""Processing parameters commonly used for training resnet on Imagenet dataset."""
image_processor = ComposeProcessing(
[Resize(size=256), CenterCrop(size=224), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), StandardizeImage(), ImagePermute()]
[Resize(size=256), CenterCrop(size=224), StandardizeImage(), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ImagePermute()]
)
params = dict(
class_names=IMAGENET_CLASSES,
image_processor=image_processor,
)
return params


def default_vit_imagenet_processing_params() -> dict:
"""Processing parameters used by ViT for training resnet on Imagenet dataset."""
image_processor = ComposeProcessing(
[Resize(size=256), CenterCrop(size=224), StandardizeImage(), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ImagePermute()]
)
params = dict(
class_names=IMAGENET_CLASSES,
Expand All @@ -668,7 +680,10 @@ def get_pretrained_processing_params(model_name: str, pretrained_weights: str) -
if pretrained_weights == "coco_pose" and model_name in ("dekr_w32_no_dc", "dekr_custom"):
return default_dekr_coco_processing_params()

if pretrained_weights == "imagenet" and model_name == "resnet18":
return default_resnet_imagenet_processing_params()
if pretrained_weights == "imagenet" and model_name in {"vit_base", "vit_large", "vit_huge"}:
return default_vit_imagenet_processing_params()

if pretrained_weights == "imagenet":
return default_imagenet_processing_params()

return dict()
11 changes: 6 additions & 5 deletions tests/unit_tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ def setUp(self) -> None:
]

def test_classification_models(self):
model = models.get(Models.RESNET18, pretrained_weights="imagenet")

with tempfile.TemporaryDirectory() as tmp_dirname:
predictions = model.predict(self.images)
predictions.show()
predictions.save(output_folder=tmp_dirname)
for model_name in {Models.RESNET18, Models.EFFICIENTNET_B0, Models.MOBILENET_V2, Models.REGNETY200}:
model = models.get(model_name, pretrained_weights="imagenet")

predictions = model.predict(self.images)
predictions.show()
predictions.save(output_folder=tmp_dirname)

def test_pose_estimation_models(self):
model = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose")
Expand Down

0 comments on commit b6499b6

Please sign in to comment.