Merge branch 'feature/damian/export_samples' of github.com:neuralmagic/sparseml into feature/damian/export_samples
Showing 11 changed files with 394 additions and 11 deletions.
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
src/sparseml/transformers/refactor_utils/export_samples.py
24 changes: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
import logging

from transformers import AutoTokenizer

from sparseml.transformers.sparsification.trainer import Trainer


__all__ = ["export_samples"]

_LOGGER = logging.getLogger(__name__)


def export_samples(
    trainer: Trainer,
    tokenizer: AutoTokenizer,
    num_samples: int,
    real_samples: bool = False,
):
    _LOGGER.info(f"Exporting {num_samples} sample inputs/outputs")
    if real_samples:
        # exporting real samples requires an evaluation dataloader; fail early
        try:
            trainer.get_eval_dataloader()
        except Exception as exception:
            raise ValueError(
                "The trainer does not contain an evaluation dataloader. "
                "Either set `real_samples=False` to generate fake samples "
                "or initialize the trainer with the `eval_dataset` argument."
            ) from exception

    trainer.save_sample_inputs_outputs(
        num_samples_to_export=num_samples,
        tokenizer=tokenizer,
    )
    _LOGGER.info(f"{num_samples} sample inputs/outputs exported")
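For reference, a usage sketch of the new helper (not part of the diff): it assumes a `trainer` and `tokenizer` already exist, for example from `initialize_transformer_model` below, and that the trainer was built with an `eval_dataset` whenever `real_samples=True`.

# hypothetical usage; `trainer` and `tokenizer` are assumed to exist already
from sparseml.transformers.refactor_utils.export_samples import export_samples

# draw 20 real input/output pairs from the trainer's evaluation dataloader
export_samples(trainer, tokenizer, num_samples=20, real_samples=True)

# per the error message above, real_samples=False falls back to fake samples
export_samples(trainer, tokenizer, num_samples=20)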
src/sparseml/transformers/refactor_utils/initialize_model.py
131 changes: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Functionality for initializing a transformer model from a given path
"""
# TODO: Add docstrings

import logging
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union

from transformers import AutoConfig, AutoTokenizer, TrainingArguments

from sparseml.transformers.sparsification import Trainer
from sparseml.transformers.utils.model import TransformerModelsRegistry


__all__ = ["initialize_transformer_model"]

_LOGGER = logging.getLogger(__name__)

@dataclass
class ForceCPUTrainingArguments(TrainingArguments):
    @property
    def place_model_on_device(self):
        # TODO: Observe how this setting influences memory consumption
        # The property governs whether or not to automatically place
        # the model on the device. Returning False ensures that the
        # model remains on the CPU during ONNX export
        return False

def initialize_transformer_model(
    model_path: Union[str, Path],
    sequence_length: int,
    task: str,
    trust_remote_code: bool = False,
    **config_args,
):
    config = initialize_config(model_path, trust_remote_code, **config_args)
    tokenizer = initialize_tokenizer(model_path, task, sequence_length)
    model = TransformerModelsRegistry.load_from_registry(task)(
        model_name_or_path=model_path,
        model_type="model",
        config=config,
        trust_remote_code=trust_remote_code,
    )
    # apply the recipe manager while the model is in train mode,
    # then switch back to eval mode for export
    model.train()
    trainer = initialize_trainer(model, model_path)
    model.eval()

    _LOGGER.info(f"Loaded model, trainer, config, and tokenizer from {model_path}")
    return model, trainer, config, tokenizer

def initialize_trainer(model: Any, model_path: Union[str, Path]) -> Trainer:
    # use the CPU-pinned arguments defined above so the model is not
    # automatically moved onto an accelerator
    training_args = ForceCPUTrainingArguments(output_dir=os.path.dirname(model_path))
    trainer = Trainer(
        model=model,
        args=training_args,
        model_state_path=model_path,
        # TODO: Do we need eval_dataset?
        # eval_dataset=eval_dataset,
        recipe=None,
        recipe_args=None,
        teacher=None,
    )
    applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)

    if not applied:
        _LOGGER.warning(
            f"No recipes were applied for {model_path}, "
            "check to make sure recipe(s) are stored in the model_path"
        )
    else:
        trainer.finalize_manager()
        num_stages = 0
        if trainer.manager:
            num_stages += trainer.manager.num_stages()
        if trainer.arch_manager:
            num_stages += trainer.arch_manager.num_stages()

        msg = (
            "an unstaged recipe"
            if num_stages == 1
            else f"a staged recipe with {num_stages} stages"
        )
        _LOGGER.info(f"Applied {msg} to the model at {model_path}")

    return trainer

def initialize_config(
    model_path: Union[str, Path], trust_remote_code: bool = False, **config_args
) -> AutoConfig:
    config = AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=trust_remote_code,
        **config_args,
    )
    return config


def initialize_tokenizer(
    model_path: Union[str, Path], task: str, sequence_length: Optional[int] = None
) -> AutoTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, model_max_length=sequence_length
    )
    if task == "text-generation":
        # generation models often ship without a pad token; reuse EOS for padding
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
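A usage sketch for the initializer (not part of the diff); the checkpoint directory and task name below are placeholders for any local transformers checkpoint:

# hypothetical path and task; sequence_length caps the tokenizer's inputs
from sparseml.transformers.refactor_utils.initialize_model import (
    initialize_transformer_model,
)

model, trainer, config, tokenizer = initialize_transformer_model(
    model_path="./my-pruned-bert-squad",
    sequence_length=384,
    task="question-answering",
)
# the model is returned in eval mode, with any recipes stored in
# model_path already applied via the trainer's recipe manager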
src/sparseml/transformers/refactor_utils/initialize_task_dataset.py
52 changes: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
from typing import Any, Callable, Dict, Optional

from transformers import PreTrainedTokenizerBase

from sparseml.transformers.masked_language_modeling import DataTrainingArguments
from sparsezoo.utils.registry import RegistryMixin


TEXT_CLASSIFICATION_TASKS = [
    "sequence-classification",
    "glue",
    "sentiment-analysis",
    "text-classification",
]


class TaskDatasetRegistry(RegistryMixin):
    @classmethod
    def load_from_registry(cls, name: str) -> Callable[..., Any]:
        return cls.get_value_from_registry(name=name)


@TaskDatasetRegistry.register(name=["masked-language-modeling", "mlm"])
def mlm_dataset_function():
    from sparseml.transformers.masked_language_modeling import (
        get_tokenized_mlm_dataset,
    )

    return get_tokenized_mlm_dataset


@TaskDatasetRegistry.register(name=["question-answering", "qa"])
def qa_dataset_function():
    from sparseml.transformers.question_answering import get_tokenized_qa_dataset

    return get_tokenized_qa_dataset


@TaskDatasetRegistry.register(name=["token-classification", "ner"])
def token_classification_dataset_function():
    from sparseml.transformers.token_classification import (
        get_tokenized_token_classification_dataset,
    )

    return get_tokenized_token_classification_dataset


@TaskDatasetRegistry.register(name=TEXT_CLASSIFICATION_TASKS)
def text_classification_dataset_function():
    from sparseml.transformers.text_classification import (
        get_tokenized_text_classification_dataset,
    )

    return get_tokenized_text_classification_dataset


def initialize_task_dataset(
    task: str,
    tokenizer: PreTrainedTokenizerBase,
    model: Optional[Any] = None,
    config: Optional[Any] = None,
    data_args: Optional[Dict[str, Any]] = None,
):
    # avoid a mutable default argument for data_args
    data_args = data_args or {}
    tokenized_task_dataset = TaskDatasetRegistry.load_from_registry(task)()
    if task in TEXT_CLASSIFICATION_TASKS:
        # text classification dataset builders additionally take the model
        # and config
        return tokenized_task_dataset(
            tokenizer=tokenizer,
            model=model,
            config=config,
            data_args=DataTrainingArguments(**data_args),
        )
    return tokenized_task_dataset(
        tokenizer=tokenizer, data_args=DataTrainingArguments(**data_args)
    )
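A usage sketch for the registry-backed helper (not part of the diff); the data_args keys are assumed to map onto DataTrainingArguments fields, and "squad" is a placeholder dataset name:

# hypothetical usage: resolve the QA dataset builder from the registry
# and tokenize a dataset with it; `tokenizer` is assumed to exist already
dataset = initialize_task_dataset(
    task="qa",
    tokenizer=tokenizer,
    data_args={"dataset_name": "squad"},
)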
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.