Skip to content

Commit

Permalink
general: adapt new marshalling to work with python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
karlicoss committed Sep 12, 2023
1 parent ec9ce78 commit 66b7250
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 35 deletions.
2 changes: 1 addition & 1 deletion benchmarks/20230912.org
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Running on @karlicoss desktop PC.
Running on @karlicoss desktop PC, =python3.10=

- serializing/deserializing here refers to converting object to json-ish python dictionary (not actual json string!)
- json dump/json load refers to converting the dict above to a json string and back
Expand Down
65 changes: 42 additions & 23 deletions src/cachew/marshall/cachew.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from collections import abc
from dataclasses import dataclass, is_dataclass
from datetime import datetime, timezone
import sys
import types
from typing import (
Any,
Dict,
List,
NamedTuple,
Optional,
Expand Down Expand Up @@ -44,7 +46,16 @@ def load(self, dct: Json) -> T:

# NOTE: using slots gives a small speedup (maybe 5%?)
# I suppose faster access to fields or something..
@dataclass(slots=True)

SLOTS: Dict[str, bool]
if sys.version_info[:2] >= (3, 10):
SLOTS = dict(slots=True)
else:
# not available :(
SLOTS = dict()


@dataclass(**SLOTS)
class Schema:
type: Any

Expand All @@ -57,7 +68,7 @@ def load(self, dct):
raise NotImplementedError


@dataclass(slots=True)
@dataclass(**SLOTS)
class SPrimitive(Schema):
def dump(self, obj):
# NOTE: returning here directly (instead of calling identity lambda) gives about 20% speedup
Expand All @@ -74,7 +85,7 @@ def load(self, dct):
# return prim(d)


@dataclass(slots=True)
@dataclass(**SLOTS)
class SDataclass(Schema):
# using list of tuples instead of dict gives about 5% speedup
fields: tuple[tuple[str, Schema], ...]
Expand All @@ -101,7 +112,7 @@ def load(self, dct):
# fmt: on


@dataclass(slots=True)
@dataclass(**SLOTS)
class SUnion(Schema):
# it's a bit faster to cache indixes here, gives about 15% speedup
args: tuple[tuple[int, Schema], ...]
Expand Down Expand Up @@ -133,7 +144,7 @@ def load(self, dct):
return s.load(val)


@dataclass(slots=True)
@dataclass(**SLOTS)
class SList(Schema):
arg: Schema

Expand All @@ -144,7 +155,7 @@ def load(self, dct):
return [self.arg.load(i) for i in dct]


@dataclass(slots=True)
@dataclass(**SLOTS)
class STuple(Schema):
args: tuple[Schema, ...]

Expand All @@ -155,7 +166,7 @@ def load(self, dct):
return tuple(a.load(i) for a, i in zip(self.args, dct))


@dataclass(slots=True)
@dataclass(**SLOTS)
class SSequence(Schema):
arg: Schema

Expand All @@ -166,7 +177,7 @@ def load(self, dct):
return tuple(self.arg.load(i) for i in dct)


@dataclass(slots=True)
@dataclass(**SLOTS)
class SDict(Schema):
ft: SPrimitive
tt: Schema
Expand Down Expand Up @@ -199,7 +210,7 @@ def _exc_helper(args):
yield a


@dataclass(slots=True)
@dataclass(**SLOTS)
class SException(Schema):
def dump(self, obj: Exception) -> Json:
return tuple(_exc_helper(obj.args))
Expand All @@ -208,7 +219,7 @@ def load(self, dct: Json):
return self.type(*dct)


@dataclass(slots=True)
@dataclass(**SLOTS)
class XDatetime(Schema):
def dump(self, obj: datetime) -> Json:
iso = obj.isoformat()
Expand Down Expand Up @@ -297,7 +308,14 @@ def build_schema(Type) -> Schema:
)

args = get_args(Type)
is_union = origin is Union or origin is types.UnionType

if sys.version_info[:2] >= (3, 10):
is_uniontype = origin is types.UnionType
else:
is_uniontype = False

is_union = origin is Union or is_uniontype

if is_union:
return SUnion(
type=Type,
Expand Down Expand Up @@ -389,27 +407,28 @@ def test_serialize_and_deserialize() -> None:

# unions
helper(1, Union[str, int])
helper('aaa', str | int)
if sys.version_info[:2] >= (3, 10):
helper('aaa', str | int)

# optionals
helper('aaa', Optional[str])
helper('aaa', str | None)
helper('aaa', str | None)
helper('aaa', Union[str, None])
helper(None , Union[str, None])

# lists
helper([1, 2, 3], list[int])
helper([1, 2, 3], List[int])
helper([1, 2, 3], List[int])
helper([1, 2, 3], Sequence[int], expected=(1, 2, 3))
helper((1, 2, 3), Sequence[int])
helper((1, 2, 3), Tuple[int, int, int])
helper((1, 2, 3), tuple[int, int, int])
helper((1, 2, 3), Tuple[int, int, int])

# dicts
helper({'a': 'aa', 'b': 'bb'}, dict[str, str])
helper({'a': None, 'b': 'bb'}, dict[str, Optional[str]])
helper({'a': 'aa', 'b': 'bb'}, Dict[str, str])
helper({'a': None, 'b': 'bb'}, Dict[str, Optional[str]])

# compounds of simple types
helper(['1', 2, '3'], list[str | int])
helper(['1', 2, '3'], List[Union[str, int]])

# TODO need to add test for equivalent dataclasses

Expand All @@ -431,12 +450,12 @@ class NT(NamedTuple):
@dataclass
class WithJson:
id: int
raw_data: dict[str, Any]
raw_data: Dict[str, Any]

# json-ish stuff
helper({}, dict[str, Any])
helper({}, Dict[str, Any])
helper(WithJson(id=123, raw_data=dict(payload='whatever', tags=['a', 'b', 'c'])), WithJson)
helper([], list[Any])
helper([], List[Any])

# exceptions
helper(RuntimeError('whatever!'), RuntimeError)
Expand All @@ -447,7 +466,7 @@ class WithJson:
Point(x=11, y=22),
RuntimeError('more stuff'),
RuntimeError(),
], list[RuntimeError | Point])
], List[Union[RuntimeError, Point]])
# fmt: on

# datetimes
Expand Down
5 changes: 4 additions & 1 deletion src/cachew/marshall/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from abc import abstractmethod
from typing import (
Any,
Dict,
Generic,
Tuple,
TypeVar,
Union,
)

Json = dict[str, Any] | tuple[Any, ...] | str | float | int | bool | None
Json = Union[Dict[str, Any], Tuple[Any, ...], str, float, int, bool, None]


T = TypeVar('T')
Expand Down
23 changes: 13 additions & 10 deletions src/cachew/tests/marshall.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import sqlite3
import sys
from typing import (
List,
Literal,
Union,
)

import orjson
Expand Down Expand Up @@ -48,9 +50,10 @@ def do_test(*, test_name: str, Type, factory, count: int, impl: Impl = 'cachew')
from typing import Union
import types

def is_union(type_) -> bool:
origin = get_origin(type_)
return origin is Union or origin is types.UnionType
# TODO use later
# def is_union(type_) -> bool:
# origin = get_origin(type_)
# return origin is Union or origin is types.UnionType

def union_structure_hook_factory(_):
def union_hook(data, type_):
Expand Down Expand Up @@ -95,12 +98,12 @@ def union_hook(data, type_):
with profile(test_name + ':baseline'), timer(f'building {count} objects of type {Type}'):
objects = list(factory(count=count))

jsons: list[Json] = [None for _ in range(count)]
jsons: List[Json] = [None for _ in range(count)]
with profile(test_name + ':serialize'), timer(f'serializing {count} objects of type {Type}'):
for i in range(count):
jsons[i] = to_json(objects[i])

strs: list[bytes] = [None for _ in range(count)] # type: ignore
strs: List[bytes] = [None for _ in range(count)] # type: ignore
with profile(test_name + ':json_dump'), timer(f'json dump {count} objects of type {Type}'):
for i in range(count):
# TODO any orjson options to speed up?
Expand All @@ -117,7 +120,7 @@ def union_hook(data, type_):
conn.executemany('INSERT INTO data (value) VALUES (?)', [(s,) for s in strs])
conn.close()

strs2: list[bytes] = [None for _ in range(count)] # type: ignore
strs2: List[bytes] = [None for _ in range(count)] # type: ignore
with profile(test_name + ':sqlite_load'), timer(f'sqlite load {count} objects of type {Type}'):
with sqlite3.connect(db) as conn:
i = 0
Expand All @@ -133,7 +136,7 @@ def union_hook(data, type_):
for s in strs:
fw.write(s + b'\n')

strs3: list[bytes] = [None for _ in range(count)] # type: ignore
strs3: List[bytes] = [None for _ in range(count)] # type: ignore
with profile(test_name + ':jsonl_load'), timer(f'jsonl load {count} objects of type {Type}'):
i = 0
with cache.open('rb') as fr:
Expand All @@ -144,7 +147,7 @@ def union_hook(data, type_):

assert strs2[:100] + strs2[-100:] == strs3[:100] + strs3[-100:] # just in case

jsons2: list[Json] = [None for _ in range(count)]
jsons2: List[Json] = [None for _ in range(count)]
with profile(test_name + ':json_load'), timer(f'json load {count} objects of type {Type}'):
for i in range(count):
# TODO any orjson options to speed up?
Expand Down Expand Up @@ -175,15 +178,15 @@ def test_union_str_dataclass(impl: Impl, count: int, gc_control, request) -> Non
pytest.skip('TODO need to adjust the handling of Union types..')

def factory(count: int):
objects: list[str | Name] = []
objects: List[Union[str, Name]] = []
for i in range(count):
if i % 2 == 0:
objects.append(str(i))
else:
objects.append(Name(first=f'first {i}', last=f'last {i}'))
return objects

do_test(test_name=request.node.name, Type=str | Name, factory=factory, count=count, impl=impl)
do_test(test_name=request.node.name, Type=Union[str, Name], factory=factory, count=count, impl=impl)


# OK, performance with calling this manually (not via pytest) is the same
Expand Down

0 comments on commit 66b7250

Please sign in to comment.