Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add return type to all public methods #36

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions tempora/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"Objects and routines pertaining to date and time (tempora)"

from __future__ import annotations

import contextlib
import datetime
import time
import re
import numbers
import functools
import contextlib
from numbers import Number
from typing import Union, Tuple, Iterable
from typing import cast
import numbers
import re
import time
from collections.abc import Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Tuple, Union, cast

import dateutil.parser
import dateutil.tz

if TYPE_CHECKING:
from typing_extensions import TypeAlias

# some useful constants
osc_per_year = 290_091_329_207_984_000
Expand Down Expand Up @@ -45,8 +48,8 @@ def _needs_year_help() -> bool:
return len(datetime.date(900, 1, 1).strftime('%Y')) != 4


AnyDatetime = Union[datetime.datetime, datetime.date, datetime.time]
StructDatetime = Union[Tuple[int, ...], time.struct_time]
AnyDatetime: TypeAlias = Union[datetime.datetime, datetime.date, datetime.time]
StructDatetime: TypeAlias = Union[Tuple[int, ...], time.struct_time]


def ensure_datetime(ob: AnyDatetime) -> datetime.datetime:
Expand All @@ -65,13 +68,14 @@ def ensure_datetime(ob: AnyDatetime) -> datetime.datetime:
return datetime.datetime.combine(date, time)


def infer_datetime(ob: Union[AnyDatetime, StructDatetime]) -> datetime.datetime:
def infer_datetime(ob: AnyDatetime | StructDatetime) -> datetime.datetime:
if isinstance(ob, (time.struct_time, tuple)):
# '"int" is not assignable to "tzinfo"', but we don't pass that many parameters
ob = datetime.datetime(*ob[:6]) # type: ignore[arg-type]
return ensure_datetime(ob)


def strftime(fmt: str, t: Union[AnyDatetime, tuple, time.struct_time]) -> str:
def strftime(fmt: str, t: AnyDatetime | tuple | time.struct_time) -> str:
"""
Portable strftime.

Expand Down Expand Up @@ -146,7 +150,7 @@ def doSubs(s):
return t.strftime(fmt)


def datetime_mod(dt, period, start=None):
def datetime_mod(dt: datetime.datetime, period, start=None) -> datetime.datetime:
"""
Find the time which is the specified date/time truncated to the time delta
relative to the start date/time.
Expand Down Expand Up @@ -190,7 +194,7 @@ def get_time_delta_microseconds(td):
return result


def datetime_round(dt, period, start=None):
def datetime_round(dt, period: datetime.timedelta, start=None) -> datetime.datetime:
"""
Find the nearest even period for the specified date/time.

Expand All @@ -210,7 +214,7 @@ def datetime_round(dt, period, start=None):
return result


def get_nearest_year_for_day(day):
def get_nearest_year_for_day(day) -> int:
"""
Returns the nearest year to now inferred from a Julian date.

Expand All @@ -235,7 +239,7 @@ def get_nearest_year_for_day(day):
return result


def gregorian_date(year, julian_day):
def gregorian_date(year, julian_day) -> datetime.date:
"""
Gregorian Date is defined as a year and a julian day (1-based
index into the days of the year).
Expand All @@ -248,7 +252,7 @@ def gregorian_date(year, julian_day):
return result


def get_period_seconds(period):
def get_period_seconds(period) -> int:
"""
return the number of seconds in the specified period

Expand Down Expand Up @@ -279,7 +283,7 @@ def get_period_seconds(period):
return result


def get_date_format_string(period):
def get_date_format_string(period) -> str:
"""
For a given period (e.g. 'month', 'day', or some numeric interval
such as 3600 (in secs)), return the format string that can be
Expand Down Expand Up @@ -307,7 +311,7 @@ def get_date_format_string(period):
if isinstance(period, str) and period.lower() == 'month':
return '%Y-%m'
file_period_secs = get_period_seconds(period)
format_pieces = ('%Y', '-%m-%d', ' %H', '-%M', '-%S')
format_pieces: Sequence[str] = ('%Y', '-%m-%d', ' %H', '-%M', '-%S')
seconds_per_second = 1
intervals = (
seconds_per_year,
Expand All @@ -321,7 +325,7 @@ def get_date_format_string(period):
return ''.join(format_pieces)


def calculate_prorated_values():
def calculate_prorated_values() -> None:
"""
>>> monkeypatch = getfixture('monkeypatch')
>>> import builtins
Expand All @@ -338,7 +342,7 @@ def calculate_prorated_values():
print(f"per {period}: {value}")


def _prorated_values(rate: str) -> Iterable[Tuple[str, Number]]:
def _prorated_values(rate: str) -> Iterator[tuple[str, float]]:
"""
Given a rate (a string in units per unit time), and return that same
rate for various time periods.
Expand All @@ -361,7 +365,7 @@ def _prorated_values(rate: str) -> Iterable[Tuple[str, Number]]:
yield period, period_value


def parse_timedelta(str):
def parse_timedelta(str) -> datetime.timedelta:
"""
Take a string representing a span of time and parse it to a time delta.
Accepts any string of comma-separated numbers each with a unit indicator.
Expand Down Expand Up @@ -455,19 +459,19 @@ def parse_timedelta(str):
return _parse_timedelta_nanos(str).resolve()


def _parse_timedelta_nanos(str):
def _parse_timedelta_nanos(str) -> _Saved_NS:
parts = re.finditer(r'(?P<value>[\d.:]+)\s?(?P<unit>[^\W\d_]+)?', str)
chk_parts = _check_unmatched(parts, str)
deltas = map(_parse_timedelta_part, chk_parts)
return sum(deltas, _Saved_NS())


def _check_unmatched(matches, text):
def _check_unmatched(matches: Iterable[re.Match[str]], text) -> Iterator[re.Match[str]]:
"""
Ensure no words appear in unmatched text.
"""

def check_unmatched(unmatched):
def check_unmatched(unmatched) -> None:
found = re.search(r'\w+', unmatched)
if found:
raise ValueError(f"Unexpected {found.group(0)!r}")
Expand Down Expand Up @@ -504,14 +508,14 @@ def check_unmatched(unmatched):
}


def _resolve_unit(raw_match):
def _resolve_unit(raw_match) -> str:
if raw_match is None:
return 'second'
text = raw_match.lower()
return _unit_lookup.get(text, text)


def _parse_timedelta_composite(raw_value, unit):
def _parse_timedelta_composite(raw_value, unit) -> _Saved_NS:
if unit != 'seconds':
raise ValueError("Cannot specify units with composite delta")
values = raw_value.split(':')
Expand All @@ -520,7 +524,7 @@ def _parse_timedelta_composite(raw_value, unit):
return _parse_timedelta_nanos(composed)


def _parse_timedelta_part(match):
def _parse_timedelta_part(match) -> _Saved_NS:
unit = _resolve_unit(match.group('unit'))
if not unit.endswith('s'):
unit += 's'
Expand Down Expand Up @@ -553,11 +557,11 @@ class _Saved_NS:
microseconds=1000,
)

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
vars(self).update(kwargs)

@classmethod
def derive(cls, unit, value):
def derive(cls, unit, value) -> _Saved_NS:
if unit == 'nanoseconds':
return _Saved_NS(nanoseconds=value)

Expand Down Expand Up @@ -588,7 +592,7 @@ def __repr__(self):
return f'_Saved_NS(td={self.td!r}, nanoseconds={self.nanoseconds!r})'


def date_range(start=None, stop=None, step=None):
def date_range(start=None, stop=None, step=None) -> Iterator[datetime.datetime]:
"""
Much like the built-in function range, but works with dates

Expand Down Expand Up @@ -643,7 +647,7 @@ def date_range(start=None, stop=None, step=None):
)


def parse(*args, **kwargs):
def parse(*args, **kwargs) -> datetime.datetime:
"""
Parse the input using dateutil.parser.parse with friendly tz support.

Expand Down
53 changes: 31 additions & 22 deletions tempora/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,31 @@
datetime.datetime(...utc)
"""

import datetime
import numbers
from __future__ import annotations

import abc
import bisect
import datetime
import numbers
from typing import TYPE_CHECKING, Any

from .utc import fromtimestamp as from_timestamp
from .utc import now

from .utc import now, fromtimestamp as from_timestamp
if TYPE_CHECKING:
from typing_extensions import Self


class DelayedCommand(datetime.datetime):
"""
A command to be executed after some delay (seconds or timedelta).
"""

delay: datetime.timedelta = datetime.timedelta()
target: Any # Expected type depends on the scheduler used

@classmethod
def from_datetime(cls, other):
def from_datetime(cls, other) -> Self:
return cls(
other.year,
other.month,
Expand All @@ -50,7 +60,7 @@ def from_datetime(cls, other):
)

@classmethod
def after(cls, delay, target):
def after(cls, delay, target) -> Self:
if not isinstance(delay, datetime.timedelta):
delay = datetime.timedelta(seconds=delay)
due_time = now() + delay
Expand All @@ -71,7 +81,7 @@ def _from_timestamp(input):
return from_timestamp(input)

@classmethod
def at_time(cls, at, target):
def at_time(cls, at, target) -> Self:
"""
Construct a DelayedCommand to come due at `at`, where `at` may be
a datetime or timestamp.
Expand All @@ -82,7 +92,7 @@ def at_time(cls, at, target):
cmd.target = target
return cmd

def due(self):
def due(self) -> bool:
return now() >= self


Expand All @@ -92,19 +102,19 @@ class PeriodicCommand(DelayedCommand):
seconds.
"""

def _next_time(self):
def _next_time(self) -> Self:
"""
Add delay to self, localized
"""
return self + self.delay

def next(self):
def next(self) -> Self:
cmd = self.__class__.from_datetime(self._next_time())
cmd.delay = self.delay
cmd.target = self.target
return cmd

def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
if key == 'delay' and not value > datetime.timedelta():
raise ValueError("A PeriodicCommand must have a positive, non-zero delay.")
super().__setattr__(key, value)
Expand All @@ -118,7 +128,7 @@ class PeriodicCommandFixedDelay(PeriodicCommand):
"""

@classmethod
def at_time(cls, at, delay, target):
def at_time(cls, at, delay, target) -> Self: # type: ignore[override] # jaraco/tempora#39
"""
>>> cmd = PeriodicCommandFixedDelay.at_time(0, 30, None)
>>> cmd.delay.total_seconds()
Expand All @@ -127,13 +137,13 @@ def at_time(cls, at, delay, target):
at = cls._from_timestamp(at)
cmd = cls.from_datetime(at)
if isinstance(delay, numbers.Number):
delay = datetime.timedelta(seconds=delay)
delay = datetime.timedelta(seconds=delay) # type: ignore[arg-type] # python/mypy#3186#issuecomment-1571512649
cmd.delay = delay
cmd.target = target
return cmd

@classmethod
def daily_at(cls, at, target):
def daily_at(cls, at, target) -> Self:
"""
Schedule a command to run at a specific time each day.

Expand All @@ -158,14 +168,13 @@ class Scheduler:
and dispatching them on schedule.
"""

def __init__(self):
self.queue = []
def __init__(self) -> None:
self.queue: list[DelayedCommand] = []

def add(self, command):
assert isinstance(command, DelayedCommand)
def add(self, command: DelayedCommand) -> None:
bisect.insort(self.queue, command)

def run_pending(self):
def run_pending(self) -> None:
while self.queue:
command = self.queue[0]
if not command.due():
Expand All @@ -176,7 +185,7 @@ def run_pending(self):
del self.queue[0]

@abc.abstractmethod
def run(self, command):
def run(self, command: DelayedCommand) -> None:
"""
Run the command
"""
Expand All @@ -187,7 +196,7 @@ class InvokeScheduler(Scheduler):
Command targets are functions to be invoked on schedule.
"""

def run(self, command):
def run(self, command: DelayedCommand) -> None:
command.target()


Expand All @@ -196,9 +205,9 @@ class CallbackScheduler(Scheduler):
Command targets are passed to a dispatch callable on schedule.
"""

def __init__(self, dispatch):
def __init__(self, dispatch) -> None:
super().__init__()
self.dispatch = dispatch

def run(self, command):
def run(self, command: DelayedCommand) -> None:
self.dispatch(command.target)
Loading
Loading