Redirect logs based on env var
jnccd committed Sep 28, 2023
1 parent 1b558ed commit 60f63ef
Showing 8 changed files with 119 additions and 56 deletions.
40 changes: 20 additions & 20 deletions documentation/source/qat_ptq_yolo_nas.md
@@ -20,27 +20,27 @@ Now, let's get to it.
 
 ## Step 0: Installations and Dataset Setup
 
-Follow the setup instructions for RF100:
+Follow the [official instructions](https://github.com/roboflow/roboflow-100-benchmark?ref=roboflow-blog) to download Roboflow100:
 
+To use this dataset, you **must** download the "coco" format, **NOT** the yolov5.
+
 ```
-- Follow the official instructions to download Roboflow100: https://github.com/roboflow/roboflow-100-benchmark?ref=roboflow-blog
-//!\\ To use this dataset, you must download the "coco" format, NOT the yolov5.
 - Your dataset should look like this:
 rf100
 ├── 4-fold-defect
 │ ├─ train
 │ │ ├─ 000000000001.jpg
 │ │ ├─ ...
 │ │ └─ _annotations.coco.json
 │ ├─ valid
 │ │ └─ ...
 │ └─ test
 │ └─ ...
 ├── abdomen-mri
 │ └─ ...
 └── ...
 - Install CoCo API: https://github.com/pdollar/coco/tree/master/PythonAPI
 ```

Install the latest version of SG:
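As a quick sanity check of the layout above: if the download really is in COCO format, the annotations should parse with the COCO API. A minimal sketch, assuming pycocotools is installed; the dataset path below is illustrative:

```python
from pycocotools.coco import COCO  # the "CoCo API" referenced above

# Illustrative path into the rf100 layout shown above
ann_file = "rf100/4-fold-defect/train/_annotations.coco.json"

coco = COCO(ann_file)  # parses and indexes the COCO-format annotations
print(f"{len(coco.getImgIds())} images, {len(coco.getCatIds())} categories")
```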
2 changes: 1 addition & 1 deletion src/super_gradients/common/auto_logging/auto_logger.py
@@ -32,7 +32,7 @@ def _setup_default_logging(self, log_level: str = None) -> None:
         # Therefore the log file will have the parent PID to being able to discriminate the logs corresponding to a single run.
         timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
         self._setup_logging(
-            filename=os.path.expanduser(f"~/sg_logs/logs_{os.getppid()}_{timestamp}.log"),
+            filename=os.path.join(env_variables.SUPER_GRADIENTS_LOG_DIR, f"logs_{os.getppid()}_{timestamp}.log"),
             copy_already_logged_messages=False,
             filemode="w",
             log_level=log_level,
3 changes: 2 additions & 1 deletion src/super_gradients/common/auto_logging/console_logging.py
@@ -7,6 +7,7 @@
from threading import Lock

 from super_gradients.common.environment.ddp_utils import multi_process_safe, is_main_process
+from super_gradients.common.environment.env_variables import env_variables
 
 
 class BufferWriter:
@@ -117,7 +118,7 @@ def __init__(self):
     @multi_process_safe
     def _setup(self):
         """On instantiation, setup the default sink file."""
-        filename = Path.home() / "sg_logs" / "console.log"
+        filename = Path(env_variables.SUPER_GRADIENTS_LOG_DIR) / "console.log"
         filename.parent.mkdir(exist_ok=True)
         self.filename = str(filename)
         os.makedirs(os.path.dirname(self.filename), exist_ok=True)
5 changes: 5 additions & 0 deletions src/super_gradients/common/environment/env_variables.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Optional


@@ -45,5 +46,9 @@ def HYDRA_FULL_ERROR(self) -> Optional[str]:
     def HYDRA_FULL_ERROR(self, value: str):
         os.environ["HYDRA_FULL_ERROR"] = value
 
+    @property
+    def SUPER_GRADIENTS_LOG_DIR(self) -> str:
+        return os.getenv("SUPER_GRADIENTS_LOG_DIR", default=str(Path.home() / "sg_logs"))


env_variables = EnvironmentVariables()
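This new variable is the single switch behind the commit: both the per-run log file in auto_logger.py and console.log above resolve against it. A minimal usage sketch (the target directory is illustrative); it should be set before super_gradients is imported, since the sinks are configured at import time:

```python
import os

# Illustrative target directory; defaults to ~/sg_logs when unset.
os.environ["SUPER_GRADIENTS_LOG_DIR"] = "/tmp/my_sg_logs"

import super_gradients  # noqa: E402 - imported after the env var on purpose
```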
@@ -113,5 +113,6 @@ def forward(self, input, target, smooth_dist=None):


 @deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=CrossEntropyLoss)
+@register_loss("LabelSmoothingCrossEntropyLoss")
 class LabelSmoothingCrossEntropyLoss(CrossEntropyLoss):
     ...
12 changes: 7 additions & 5 deletions src/super_gradients/training/models/model_factory.py
@@ -153,6 +153,13 @@ def instantiate_model(
     net = architecture_cls(arch_params=arch_params)
 
     if pretrained_weights:
+        # The logic is as follows - first we initialize the preprocessing params using default hard-coded params.
+        # If the pretrained checkpoint contains preprocessing params, the new params will be loaded and override the
+        # ones from this step in load_pretrained_weights_local/load_pretrained_weights.
+        if isinstance(net, HasPredict):
+            processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
+            net.set_dataset_processing_params(**processing_params)
+
         if is_remote and pretrained_weights_path:
             load_pretrained_weights_local(net, model_name, pretrained_weights_path)
         else:
@@ -162,11 +169,6 @@
             net.replace_head(new_num_classes=num_classes_new_head)
             arch_params.num_classes = num_classes_new_head
 
-        # STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
-        if isinstance(net, HasPredict):
-            processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
-            net.set_dataset_processing_params(**processing_params)
-
     _add_model_name_attribute(net, model_name)
 
     return net
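A sketch of what the reordering above means for callers (the input image is a placeholder): default preprocessing params are applied first, so .predict() keeps working when a checkpoint ships without a recipe, while params stored in the checkpoint still win because they are loaded afterwards:

```python
import numpy as np

from super_gradients.common.object_names import Models
from super_gradients.training import models

# Default preprocessing params are set before the weights are loaded;
# if the checkpoint carries its own processing_params, those override them.
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")
model.predict(np.zeros((640, 640, 3), dtype=np.uint8))  # placeholder image
```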
83 changes: 57 additions & 26 deletions src/super_gradients/training/utils/checkpoint_utils.py
@@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
     message_model = "model" if not load_backbone else "model's backbone"
     logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
 
-    if (isinstance(net, HasPredict)) and load_processing_params:
-        if "processing_params" not in checkpoint.keys():
-            raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
-        try:
-            net.set_dataset_processing_params(**checkpoint["processing_params"])
-        except Exception as e:
-            logger.warning(
-                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
-                "predict make sure to call set_dataset_processing_params."
-            )
+    _maybe_load_preprocessing_params(net, checkpoint)
 
     if load_weights_only or load_backbone:
         # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
@@ -1549,10 +1540,12 @@ def __init__(self, desc):
 def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
     """
     Loads pretrained weights from the MODEL_URLS dictionary to model
     :param architecture: name of the model's architecture
     :param model: model to load pretrained weights for
     :param pretrained_weights: name of the pretrained weights (e.g. imagenet)
     :return: None
     """
     from super_gradients.common.object_names import Models

@@ -1569,19 +1562,19 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"By downloading the pre-trained weight files you agree to comply with these terms."
)

unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
_load_weights(architecture, model, pretrained_state_dict)

# Basically this check allows settings pretrained weights from local path using file:///path/to/weights scheme
# which is a valid URI scheme for local files
# Supporting local files and file URI allows us modification of pretrained weights dics in unit tests
if url.startswith("file://") or os.path.exists(url):
pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
else:
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)

def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
@@ -1598,3 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):

     pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
     _load_weights(architecture, model, pretrained_state_dict)
+    _maybe_load_preprocessing_params(model, pretrained_state_dict)
+
+
+def _load_weights(architecture, model, pretrained_state_dict):
+    if "ema_net" in pretrained_state_dict.keys():
+        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
+    solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
+    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
+    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
+
+
+def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
+    """
+    Tries to load preprocessing params from the checkpoint to the model.
+    The function does not crash, and raises a warning if the loading fails.
+    :param model: Instance of nn.Module
+    :param checkpoint: Entire checkpoint dict (not state_dict with model weights)
+    :return: True if the loading was successful, False otherwise.
+    """
+    model = unwrap_model(model)
+    checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
+    model_has_predict = isinstance(model, HasPredict)
+    logger.debug(
+        f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
+        f"Model {model.__class__.__name__} inherits HasPredict: {model_has_predict}"
+    )
+
+    if model_has_predict and checkpoint_has_preprocessing_params:
+        try:
+            model.set_dataset_processing_params(**checkpoint["processing_params"])
+            logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
+            return True
+        except Exception as e:
+            logger.warning(
+                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling "
+                "predict make sure to call set_dataset_processing_params."
+            )
+    return False
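Together, the file:// branch in load_pretrained_weights and _maybe_load_preprocessing_params let the whole pretrained-weights path run against a local checkpoint, which is exactly what the new unit test below exercises. A condensed sketch (the checkpoint path is illustrative; the "_test"/"test" keys mirror the test):

```python
import torch

from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES

# A plain existing path or a file:// URI both take the torch.load() branch.
MODEL_URLS[Models.YOLO_NAS_S + "_test"] = "file:///tmp/yolo_nas_s_coco.pth"  # illustrative
PRETRAINED_NUM_CLASSES["test"] = 80

# Weights load locally, and any processing_params stored in the checkpoint
# are applied via _maybe_load_preprocessing_params.
model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
```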
29 changes: 26 additions & 3 deletions tests/unit_tests/pretrained_models_unit_test.py
@@ -1,12 +1,19 @@
+import os
+import shutil
+import tempfile
 import unittest
 
+import numpy as np
+import torch
+
 import super_gradients
 from super_gradients.common.object_names import Models
-from super_gradients.training import models
 from super_gradients.training import Trainer
+from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
-import os
-import shutil
+from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
+from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params


class PretrainedModelsUnitTest(unittest.TestCase):
@@ -29,6 +36,22 @@ def test_pretrained_repvgg_a0_imagenet(self):
         model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
+    def test_pretrained_models_load_preprocessing_params(self):
+        """
+        Test that checks whether preprocessing params from a pretrained model load correctly.
+        """
+        state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
+        with tempfile.TemporaryDirectory() as td:
+            checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
+            torch.save(state, checkpoint_path)
+
+            MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
+            PRETRAINED_NUM_CLASSES["test"] = 80
+
+            model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
+            # .predict() would fail if the model had no preprocessing params
+            self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))

     def tearDown(self) -> None:
         if os.path.exists("~/.cache/torch/hub/"):
             shutil.rmtree("~/.cache/torch/hub/")
