From 8bf0e09d2e7a8520ce478a84a8a943494173807a Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 7 Dec 2023 17:56:29 -0800 Subject: [PATCH 1/3] time to clean up time parsing --- composer/callbacks/activation_monitor.py | 7 +- composer/callbacks/image_visualizer.py | 7 +- composer/core/time.py | 88 ++++++++++++++++-------- composer/loggers/console_logger.py | 6 +- composer/loggers/slack_logger.py | 5 +- composer/utils/misc.py | 6 +- tests/test_time.py | 63 +++++++++++++++++ 7 files changed, 129 insertions(+), 53 deletions(-) diff --git a/composer/callbacks/activation_monitor.py b/composer/callbacks/activation_monitor.py index 8e11976b32..b30f65aea4 100644 --- a/composer/callbacks/activation_monitor.py +++ b/composer/callbacks/activation_monitor.py @@ -98,12 +98,7 @@ def __init__(self, self.handles = [] # Check that the interval timestring is parsable and convert into time object - if isinstance(interval, int): - self.interval = Time(interval, TimeUnit.BATCH) - elif isinstance(interval, str): - self.interval = Time.from_timestring(interval) - elif isinstance(interval, Time): - self.interval = interval + self.interval = Time.from_input(interval, TimeUnit.BATCH) if self.interval.unit == TimeUnit.BATCH and self.interval < Time.from_timestring('10ba'): warnings.warn(f'Currently the ActivationMonitor`s interval is set to {self.interval} ' diff --git a/composer/callbacks/image_visualizer.py b/composer/callbacks/image_visualizer.py index 7904bc3000..c1a9379665 100644 --- a/composer/callbacks/image_visualizer.py +++ b/composer/callbacks/image_visualizer.py @@ -49,7 +49,7 @@ class ImageVisualizer(Callback): This callback only works with wandb logging for now. Args: - interval (str | Time, optional): Time string specifying how often to log train images. For example, ``interval='1ep'`` + interval (int | str | Time, optional): Time string specifying how often to log train images. For example, ``interval='1ep'`` means images are logged once every epoch, while ``interval='100ba'`` means images are logged once every 100 batches. Eval images are logged once at the start of each eval. Default: ``"100ba"``. mode (str, optional): How to log the image labels. Valid values are ``"input"`` (input only) @@ -86,10 +86,7 @@ def __init__(self, raise ValueError(f'Invalid mode: {mode}') # Check that the interval timestring is parsable and convert into time object - if isinstance(interval, int): - self.interval = Time(interval, TimeUnit.BATCH) - if isinstance(interval, str): - self.interval = Time.from_timestring(interval) + self.interval = Time.from_input(interval, TimeUnit.BATCH) # Verify that the interval has supported units if self.interval.unit not in [TimeUnit.BATCH, TimeUnit.EPOCH]: diff --git a/composer/core/time.py b/composer/core/time.py index 5e1c5ee4e7..6a024ab6d4 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -19,6 +19,7 @@ import datetime import re +import warnings from typing import Any, Dict, Generic, Optional, TypeVar, Union, cast from composer.core.serializable import Serializable @@ -223,20 +224,12 @@ def to_timestring(self): """ return str(self) - def _parse(self, other: object) -> Time: + def _parse(self, other: Union[int, float, Time, str]) -> Time: # parse ``other`` into a Time object - if isinstance(other, Time): - return other - if isinstance(other, int): - return Time(other, self.unit) - if isinstance(other, str): - other_parsed = Time.from_timestring(other) - return other_parsed - - raise TypeError(f'Cannot convert type {other} to {self.__class__.__name__}') + return Time.from_input(other, self.unit) def _cmp(self, other: Union[int, float, Time, str]) -> int: - # When doing comparisions, and other is an integer (or float), we can safely infer + # When doing comparisons, and other is an integer (or float), we can safely infer # the unit from self.unit # E.g. calls like this should be allowed: if batch < 42: do_something() # This eliminates the need to call .value everywhere @@ -302,7 +295,7 @@ def __int__(self): def __float__(self): return float(self.value) - def __truediv__(self, other: object) -> Time[float]: + def __truediv__(self, other: Union[int, float, Time, str]) -> Time[float]: if isinstance(other, (float, int)): return Time(type(self.value)(self.value / other), self.unit) other = self._parse(other) @@ -310,7 +303,19 @@ def __truediv__(self, other: object) -> Time[float]: raise RuntimeError(f'Cannot divide {self} by {other} since they have different units.') return Time(self.value / other.value, TimeUnit.DURATION) - def __mul__(self, other: object): + def __floordiv__(self, other: Union[int, float, Time, str]) -> Time[int]: + other = self._parse(other) + if self.unit != other.unit: + raise RuntimeError(f'Cannot divide {self} by {other} since they have different units.') + return Time(self.value // other.value, TimeUnit.DURATION) + + def __mod__(self, other: Union[int, float, Time, str]) -> Time[TValue]: + other = self._parse(other) + if self.unit != other.unit: + raise RuntimeError(f'Cannot take mod of {self} by {other} since they have different units.') + return Time(self.value % other.value, self.unit) + + def __mul__(self, other: Union[int, float, Time, str]): if isinstance(other, (float, int)): # Scale by the value. return Time(type(self.value)(self.value * other), self.unit) @@ -321,12 +326,43 @@ def __mul__(self, other: object): real_type = float if real_unit == TimeUnit.DURATION else int return Time(real_type(self.value * other.value), real_unit) - def __rmul__(self, other: object): + def __rmul__(self, other: Union[int, float, Time, str]): return self * other def __hash__(self): return hash((self.value, self.unit)) + @classmethod + def from_input(cls, + i: Union[str, int, float, 'Time'], + default_int_unit: Optional[Union[TimeUnit, str]] = None) -> Time: + """Parse a time input into a :class:`Time` instance. + + Args: + i (str | int | Time): The time input. + default_int_unit (TimeUnit, optional): The default unit to use if ``i`` is an integer + + >>> Time.from_input("5ep") + Time(5, TimeUnit.EPOCH) + >>> Time.from_input(5, TimeUnit.EPOCH) + Time(5, TimeUnit.EPOCH) + + Returns: + Time: An instance of :class:`Time`. + """ + if isinstance(i, Time): + return i + + if isinstance(i, str): + return Time.from_timestring(i) + + if isinstance(i, int) or isinstance(i, float): + if default_int_unit is None: + raise RuntimeError('default_int_unit must be specified when constructing Time from an integer.') + return Time(i, default_int_unit) + + raise RuntimeError(f'Cannot convert type {i} to {cls.__name__}') + @classmethod def from_timestring(cls, timestring: str) -> Time: """Parse a time string into a :class:`Time` instance. @@ -563,7 +599,7 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]: return self.token raise ValueError(f'Invalid unit: {unit}') - def _parse(self, other: object) -> Time: + def _parse(self, other: Union[int, float, Time, str]) -> Time: # parse ``other`` into a Time object if isinstance(other, Time): return other @@ -573,7 +609,7 @@ def _parse(self, other: object) -> Time: raise TypeError(f'Cannot convert type {other} to {self.__class__.__name__}') - def __eq__(self, other: object): + def __eq__(self, other: Union[int, float, Time, str]): if not isinstance(other, (Time, Timestamp, str)): return NotImplemented if isinstance(other, Timestamp): @@ -582,7 +618,7 @@ def __eq__(self, other: object): self_counter = self.get(other.unit) return self_counter == other - def __ne__(self, other: object): + def __ne__(self, other: Union[int, float, Time, str]): if not isinstance(other, (Time, Timestamp, str)): return NotImplemented if isinstance(other, Timestamp): @@ -591,28 +627,28 @@ def __ne__(self, other: object): self_counter = self.get(other.unit) return self_counter != other - def __lt__(self, other: object): + def __lt__(self, other: Union[int, float, Time, str]): if not isinstance(other, (Time, str)): return NotImplemented other = self._parse(other) self_counter = self.get(other.unit) return self_counter < other - def __le__(self, other: object): + def __le__(self, other: Union[int, float, Time, str]): if not isinstance(other, (Time, str)): return NotImplemented other = self._parse(other) self_counter = self.get(other.unit) return self_counter <= other - def __gt__(self, other: object): + def __gt__(self, other: Union[int, float, Time, str]): if not isinstance(other, (Time, str)): return NotImplemented other = self._parse(other) self_counter = self.get(other.unit) return self_counter > other - def __ge__(self, other: object): + def __ge__(self, other: Union[int, float, Time, str]): if not isinstance(other, (Time, str)): return NotImplemented other = self._parse(other) @@ -783,10 +819,6 @@ def ensure_time(maybe_time: Union[Time, str, int], int_unit: Union[TimeUnit, str Returns: Time: An instance of :class:`.Time`. """ - if isinstance(maybe_time, str): - return Time.from_timestring(maybe_time) - if isinstance(maybe_time, int): - return Time(maybe_time, int_unit) - if isinstance(maybe_time, Time): - return maybe_time - raise TypeError(f'Unsupported type for ensure_time: {type(maybe_time)}') + warnings.warn('ensure_time is deprecated. Use Time.from_input instead.', DeprecationWarning) + + return Time.from_input(maybe_time, int_unit) diff --git a/composer/loggers/console_logger.py b/composer/loggers/console_logger.py index e25ac4f268..df97cdff04 100644 --- a/composer/loggers/console_logger.py +++ b/composer/loggers/console_logger.py @@ -44,11 +44,7 @@ def __init__(self, stream: Union[str, TextIO] = sys.stderr, log_traces: bool = False) -> None: - if isinstance(log_interval, int): - log_interval = Time(log_interval, TimeUnit.EPOCH) - if isinstance(log_interval, str): - log_interval = Time.from_timestring(log_interval) - + log_interval = Time.from_input(log_interval, TimeUnit.EPOCH) self.last_logged_batch = 0 if log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH): diff --git a/composer/loggers/slack_logger.py b/composer/loggers/slack_logger.py index 15126d8ffd..1962f5f72f 100644 --- a/composer/loggers/slack_logger.py +++ b/composer/loggers/slack_logger.py @@ -102,10 +102,7 @@ def __init__( # Create a regex of all keys to include self.regex_all_keys = '(' + ')|('.join(include_keys) + ')' - if isinstance(log_interval, int): - self.log_interval = Time(log_interval, TimeUnit.EPOCH) - if isinstance(log_interval, str): - self.log_interval = Time.from_timestring(log_interval) + self.log_interval = Time.from_input(log_interval, TimeUnit.EPOCH) if self.log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH): raise ValueError('The `slack logger log_interval` argument must have units of EPOCH or BATCH.') diff --git a/composer/utils/misc.py b/composer/utils/misc.py index 9805ab453e..76573f8901 100644 --- a/composer/utils/misc.py +++ b/composer/utils/misc.py @@ -52,11 +52,7 @@ def create_interval_scheduler(interval: Union[str, int, 'Time'], if final_events is None: final_events = {Event.BATCH_CHECKPOINT, Event.EPOCH_CHECKPOINT} - if isinstance(interval, str): - interval = Time.from_timestring(interval) - if isinstance(interval, int): - interval = Time(interval, TimeUnit.EPOCH) - + interval = Time.from_input(interval, TimeUnit.EPOCH) if interval.unit == TimeUnit.EPOCH: interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}: diff --git a/tests/test_time.py b/tests/test_time.py index e7b03b8923..86bca1fc68 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -54,6 +54,69 @@ def test_time_math(): assert t4 * 2 == Time.from_timestring('1dur') assert t1 / t2 == t4 assert t2 / 2 == t1 + assert t3 // 2 == t1 + assert t3 // t2 == t1 + assert t3 % t3 == Time.from_timestring('0ep') + assert t3 % t2 == t1 + + +def test_invalid_math(): + t1 = Time.from_timestring('1ep') + t2 = Time.from_timestring('1ba') + + with pytest.raises(RuntimeError): + _ = t1 > t2 + + with pytest.raises(RuntimeError): + _ = t1 < t2 + + with pytest.raises(RuntimeError): + _ = t1 >= t2 + + with pytest.raises(RuntimeError): + _ = t1 <= t2 + + with pytest.raises(RuntimeError): + _ = t1 == t2 + + with pytest.raises(RuntimeError): + _ = t1 != t2 + + with pytest.raises(RuntimeError): + _ = t1 + t2 + + with pytest.raises(RuntimeError): + _ = t1 - t2 + + with pytest.raises(RuntimeError): + _ = t1 / t2 + + with pytest.raises(RuntimeError): + _ = t1 // t2 + + with pytest.raises(RuntimeError): + _ = t1 % t2 + + with pytest.raises(RuntimeError): + _ = t1 * t2 + + +def test_time_from_input(): + expected = Time(1, TimeUnit.EPOCH) + + assert Time.from_input(expected) == expected + assert Time.from_input('1ep') == expected + assert Time.from_input(1, TimeUnit.EPOCH) == expected + assert Time.from_input(1, 'ep') == expected + + with pytest.raises(RuntimeError): + Time.from_input(None) # type: ignore + + with pytest.raises(RuntimeError): + Time.from_input([123]) # type: ignore + + with pytest.raises(RuntimeError): + Time.from_input(1) def test_time_repr(): From 402d8eca11aeb76dc540d1b9e0ad71b0ce643351 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Thu, 7 Dec 2023 18:03:45 -0800 Subject: [PATCH 2/3] fix type error --- composer/core/time.py | 17 +++++++---------- composer/trainer/trainer.py | 3 +++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/composer/core/time.py b/composer/core/time.py index 6a024ab6d4..c9fce35d0c 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -19,7 +19,6 @@ import datetime import re -import warnings from typing import Any, Dict, Generic, Optional, TypeVar, Union, cast from composer.core.serializable import Serializable @@ -429,39 +428,39 @@ def __init__( epoch_wct: Optional[datetime.timedelta] = None, batch_wct: Optional[datetime.timedelta] = None, ): - epoch = ensure_time(epoch, TimeUnit.EPOCH) + epoch = Time.from_input(epoch, TimeUnit.EPOCH) if epoch.unit != TimeUnit.EPOCH: raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.') self._epoch = epoch - batch = ensure_time(batch, TimeUnit.BATCH) + batch = Time.from_input(batch, TimeUnit.BATCH) if batch.unit != TimeUnit.BATCH: raise ValueError(f'The `batch` argument has units of {batch.unit}; not {TimeUnit.BATCH}.') self._batch = batch - sample = ensure_time(sample, TimeUnit.SAMPLE) + sample = Time.from_input(sample, TimeUnit.SAMPLE) if sample.unit != TimeUnit.SAMPLE: raise ValueError(f'The `sample` argument has units of {sample.unit}; not {TimeUnit.SAMPLE}.') self._sample = sample - token = ensure_time(token, TimeUnit.TOKEN) + token = Time.from_input(token, TimeUnit.TOKEN) if token.unit != TimeUnit.TOKEN: raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.') self._token = token - batch_in_epoch = ensure_time(batch_in_epoch, TimeUnit.BATCH) + batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH) if batch_in_epoch.unit != TimeUnit.BATCH: raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; ' f'not {TimeUnit.BATCH}.')) self._batch_in_epoch = batch_in_epoch - sample_in_epoch = ensure_time(sample_in_epoch, TimeUnit.SAMPLE) + sample_in_epoch = Time.from_input(sample_in_epoch, TimeUnit.SAMPLE) if sample_in_epoch.unit != TimeUnit.SAMPLE: raise ValueError((f'The `sample_in_epoch` argument has units of {sample_in_epoch.unit}; ' f'not {TimeUnit.SAMPLE}.')) self._sample_in_epoch = sample_in_epoch - token_in_epoch = ensure_time(token_in_epoch, TimeUnit.TOKEN) + token_in_epoch = Time.from_input(token_in_epoch, TimeUnit.TOKEN) if token_in_epoch.unit != TimeUnit.TOKEN: raise ValueError((f'The `token_in_epoch` argument has units of {token_in_epoch.unit}; ' f'not {TimeUnit.TOKEN}.')) @@ -819,6 +818,4 @@ def ensure_time(maybe_time: Union[Time, str, int], int_unit: Union[TimeUnit, str Returns: Time: An instance of :class:`.Time`. """ - warnings.warn('ensure_time is deprecated. Use Time.from_input instead.', DeprecationWarning) - return Time.from_input(maybe_time, int_unit) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 8e17a1b87b..8fb99d45aa 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2018,6 +2018,9 @@ def _train_loop(self) -> None: finished_epoch_early = False last_wct = datetime.datetime.now() + if self.state.max_duration is None: + raise RuntimeError('max_duration must be specified when calling Trainer.fit()') + while self.state.timestamp < self.state.max_duration: try: if int(self.state.timestamp.batch_in_epoch) == 0: From 0f7adc02f07c8d010485a8eaed33fcdfc80efbc2 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Fri, 8 Dec 2023 10:43:41 -0800 Subject: [PATCH 3/3] updates --- composer/core/time.py | 6 ------ composer/trainer/trainer.py | 4 +++- tests/test_time.py | 5 ----- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/composer/core/time.py b/composer/core/time.py index c9fce35d0c..ab2d6a60ee 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -302,12 +302,6 @@ def __truediv__(self, other: Union[int, float, Time, str]) -> Time[float]: raise RuntimeError(f'Cannot divide {self} by {other} since they have different units.') return Time(self.value / other.value, TimeUnit.DURATION) - def __floordiv__(self, other: Union[int, float, Time, str]) -> Time[int]: - other = self._parse(other) - if self.unit != other.unit: - raise RuntimeError(f'Cannot divide {self} by {other} since they have different units.') - return Time(self.value // other.value, TimeUnit.DURATION) - def __mod__(self, other: Union[int, float, Time, str]) -> Time[TValue]: other = self._parse(other) if self.unit != other.unit: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 8fb99d45aa..062c1cfb36 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2019,7 +2019,9 @@ def _train_loop(self) -> None: last_wct = datetime.datetime.now() if self.state.max_duration is None: - raise RuntimeError('max_duration must be specified when calling Trainer.fit()') + # This is essentially just a type check, as max_duration should always be + # asserted to be not None when Trainer.fit() is called + raise RuntimeError('max_duration must be specified when initializing the Trainer') while self.state.timestamp < self.state.max_duration: try: diff --git a/tests/test_time.py b/tests/test_time.py index 86bca1fc68..58f1cf9747 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -54,8 +54,6 @@ def test_time_math(): assert t4 * 2 == Time.from_timestring('1dur') assert t1 / t2 == t4 assert t2 / 2 == t1 - assert t3 // 2 == t1 - assert t3 // t2 == t1 assert t3 % t3 == Time.from_timestring('0ep') assert t3 % t2 == t1 @@ -91,9 +89,6 @@ def test_invalid_math(): with pytest.raises(RuntimeError): _ = t1 / t2 - with pytest.raises(RuntimeError): - _ = t1 // t2 - with pytest.raises(RuntimeError): _ = t1 % t2