Skip to content

Commit

Permalink
[Export refactor] final manual testing fixes (#1948)
Browse files Browse the repository at this point in the history
* [Export refactor] final manual testing fixes

* review
  • Loading branch information
bfineran authored Jan 10, 2024
1 parent c3c90a4 commit 57a4dd0
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/sparseml/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .export import *
3 changes: 1 addition & 2 deletions src/sparseml/export/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import logging
import os
import shutil
from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Union

from sparseml.exporters import ExportTargets

Expand Down
38 changes: 28 additions & 10 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,19 @@ def resolve_integration(
will attempt to infer it from the source_path.
:return: The name of the integration to use for exporting the model.
"""
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)
from sparseml.transformers.utils.helpers import is_transformer_model
try:
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)
except ImportError:
# unable to import integration, always return False
is_image_classification_model = _null_is_model

try:
from sparseml.transformers.utils.helpers import is_transformer_model
except ImportError:
# unable to import integration, always return False
is_transformer_model = _null_is_model

if (
integration == Integrations.image_classification.value
Expand All @@ -63,7 +72,6 @@ def resolve_integration(
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401

return Integrations.image_classification.value

elif integration == Integrations.transformers.value or is_transformer_model(
source_path
):
Expand All @@ -80,6 +88,12 @@ def resolve_integration(
)


def _null_is_model(*args, **kwargs):
# convenience function to always return False for an integration
# to be used if that integration is not importable
return False


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps names to helper functions
Expand All @@ -88,7 +102,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""

create_model: Callable[
[Union[str, Path], ...],
[Union[str, Path]],
Tuple[
"torch.nn.Module", # noqa F821
Optional[Dict[str, Any]],
Expand All @@ -102,13 +116,13 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"- (optionally) loaded_model_kwargs "
"(any relevant objects created along with the model)"
)
create_dummy_input: Callable[..., "torch.Tensor"] = Field( # noqa F821
create_dummy_input: Callable[[Any], "torch.Tensor"] = Field( # noqa F821
description="A function that takes: "
"- appropriate arguments "
"and returns: "
"- a dummy input for the model (a torch.Tensor) "
)
export: Callable[..., str] = Field(
export: Callable[[Any], str] = Field(
description="A function that takes: "
" - a (sparse) PyTorch model "
" - sample input data "
Expand All @@ -120,15 +134,19 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"and returns the path to the exported model",
default=export_model,
)
apply_optimizations: Optional[Callable[..., None]] = Field(
apply_optimizations: Optional[Callable[[Any], None]] = Field(
description="A function that takes:"
" - path to the exported model"
" - names of the optimizations to apply"
" and applies the optimizations to the model",
)

create_data_samples: Callable[
Tuple[Optional["torch.nn.Module"], int, Optional[Dict[str, Any]]], # noqa F821
[
Tuple[
Optional["torch.nn.Module"], int, Optional[Dict[str, Any]] # noqa: F821
]
],
Tuple[
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
Expand Down
11 changes: 8 additions & 3 deletions src/sparseml/transformers/sparsification/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,14 @@ def _add_tensorboard_logger_if_available(self):
self.args, log_dir=self.args.logging_dir
)

self.logger_manager.add_logger(
TensorBoardLogger(writer=tensorboard_callback.tb_writer)
)
try:
self.logger_manager.add_logger(
TensorBoardLogger(writer=tensorboard_callback.tb_writer)
)
except (ImportError, ModuleNotFoundError):
_LOGGER.info(
f"Unable to import tensorboard - running without tensorboard logging"
)

def _get_fake_dataloader(
self,
Expand Down

0 comments on commit 57a4dd0

Please sign in to comment.