feat: add cached dataset wrapper #148

Merged: 69 commits into master from cached-dataset, Aug 24, 2021

Commits (69)
62325c6  init commit (wangraying, Jul 28, 2021)
2c640ba  add (wangraying, Jul 28, 2021)
9dda8cc  update (wangraying, Jul 28, 2021)
d41667e  update redis config (wangraying, Jul 28, 2021)
5d130b7  . (wangraying, Jul 29, 2021)
b6f36d2  fix (wangraying, Jul 29, 2021)
ec998cc  batch writes (wangraying, Jul 31, 2021)
6c6ab3d  add cluster (wangraying, Aug 4, 2021)
241b21e  Update bagua/torch_api/contrib/utils/redis_store.py (wangraying, Aug 4, 2021)
14d6d17  add (wangraying, Aug 11, 2021)
4b7a7d6  Update bagua/torch_api/contrib/cache_dataset.py (wangraying, Aug 11, 2021)
0bb1d15  Update bagua/torch_api/contrib/cache_dataset.py (wangraying, Aug 11, 2021)
be8dfa8  Apply suggestions from code review (wangraying, Aug 11, 2021)
a1a9c78  Update bagua/torch_api/contrib/cache_dataset.py (wangraying, Aug 11, 2021)
e864347  format (wangraying, Aug 12, 2021)
7ab8675  update doc (wangraying, Aug 12, 2021)
53ae811  . (wangraying, Aug 12, 2021)
b53a186  . (wangraying, Aug 12, 2021)
bb4acad  refine docs (wangraying, Aug 12, 2021)
8811888  add (wangraying, Aug 12, 2021)
d90c098  update doc (wangraying, Aug 12, 2021)
ebdd520  add import error (wangraying, Aug 12, 2021)
505f55b  Apply suggestions from code review (wangraying, Aug 12, 2021)
50c4ea7  add package check (wangraying, Aug 12, 2021)
401e3d4  update (wangraying, Aug 12, 2021)
471bb57  update test (wangraying, Aug 12, 2021)
333616f  Merge branch 'master' into cached-dataset (wangraying, Aug 12, 2021)
eaad845  Merge branch 'cached-dataset' of https://github.com/BaguaSys/bagua in… (wangraying, Aug 12, 2021)
a95c260  fix type (wangraying, Aug 12, 2021)
1c2d4c9  add (wangraying, Aug 12, 2021)
179d939  . (wangraying, Aug 12, 2021)
5ddef17  pytype (wangraying, Aug 12, 2021)
62a5753  . (wangraying, Aug 12, 2021)
06ba1f5  auto cleanup (wangraying, Aug 12, 2021)
75c9f2c  pytype (wangraying, Aug 12, 2021)
eb8fde5  . (wangraying, Aug 12, 2021)
45bf719  doc doc (wangraying, Aug 12, 2021)
89aed0a  update hash (wangraying, Aug 13, 2021)
351bf23  Apply suggestions from code review (wangraying, Aug 13, 2021)
51b80ee  style (wangraying, Aug 13, 2021)
4d390e1  Merge branch 'cached-dataset' of https://github.com/BaguaSys/bagua in… (wangraying, Aug 13, 2021)
7a6f357  update after review (wangraying, Aug 13, 2021)
4894642  doc (wangraying, Aug 13, 2021)
b262801  d (wangraying, Aug 13, 2021)
15f2977  share redis server (wangraying, Aug 13, 2021)
d7c493d  Merge branch 'cached-dataset' of https://github.com/BaguaSys/bagua in… (wangraying, Aug 13, 2021)
0027c97  . (wangraying, Aug 13, 2021)
32c066d  Update tests/contrib/test_cached_dataset.py (wangraying, Aug 13, 2021)
c9022bc  devdev (wangraying, Aug 13, 2021)
1e21deb  update (wangraying, Aug 13, 2021)
792ab7d  . (wangraying, Aug 13, 2021)
4c0b046  Update bagua/torch_api/contrib/cached_dataset.py (wangraying, Aug 13, 2021)
cd8afbe  fmt (wangraying, Aug 13, 2021)
0f49c2e  . (wangraying, Aug 16, 2021)
d3fd996  Merge branch 'master' into cached-dataset (wangraying, Aug 16, 2021)
ec684a8  add (wangraying, Aug 19, 2021)
09d4f45  Update cache_loader.py (NOBLES5E, Aug 23, 2021)
a86eb95  Update cached_dataset.py (NOBLES5E, Aug 23, 2021)
3291b42  Update cache_loader.py (NOBLES5E, Aug 23, 2021)
5a67148  Update redis_store.py (NOBLES5E, Aug 23, 2021)
4553c3b  Update store.py (NOBLES5E, Aug 23, 2021)
8359617  Update store.py (NOBLES5E, Aug 23, 2021)
58b8e64  Update redis_store.py (NOBLES5E, Aug 23, 2021)
34d7672  Update redis_store.py (NOBLES5E, Aug 23, 2021)
24cd3c7  update doc (wangraying, Aug 23, 2021)
348bae1  .. (wangraying, Aug 23, 2021)
c94f67c  . (wangraying, Aug 23, 2021)
1a97a30  format (wangraying, Aug 24, 2021)
db0ed14  . (wangraying, Aug 24, 2021)
2 changes: 2 additions & 0 deletions bagua/torch_api/contrib/__init__.py
@@ -3,3 +3,5 @@
    LoadBalancingDistributedSampler,
    LoadBalancingDistributedBatchSampler,
)
from .cache_loader import CacheLoader # noqa: F401
from .cached_dataset import CachedDataset # noqa: F401
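With these re-exports, both wrappers become importable from the contrib package root, which is what the docstring examples below rely on. A one-line sanity check (assuming ``bagua`` is installed):

    # Both names resolve from the package root after this change.
    from bagua.torch_api.contrib import CacheLoader, CachedDataset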
139 changes: 139 additions & 0 deletions bagua/torch_api/contrib/cache_loader.py
@@ -0,0 +1,139 @@
import pickle
from collections import defaultdict
from typing import Any, Callable


__all__ = ["CacheLoader"]


def serialize(input):
    return pickle.dumps(input)


def deserialize(input):
    return pickle.loads(input)


class CacheLoader:
    def __init__(
        self,
        backend: str = "redis",
        dataset_name: str = "",
        writer_buffer_size: int = 1,
        **kwargs,
    ):
        """
        Cache loader caches values calculated by an expensive function, indexed by their keys via :meth:`get`,
        so that the values can be retrieved faster the next time.

        Internally, values are indexed by ``"{dataset_name}_{key}"`` and saved in a distributed key-value
        store, where ``dataset_name`` is specified at initialization, and ``key`` is the argument in :meth:`get`.

        By default, cache loader uses :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` as its backend
        distributed key-value store implementation. It supports using a list of existing Redis servers or spawning
        new Redis servers. Parameters for :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` can be
        provided here in ``**kwargs``.

        Args:
            backend(str): Backend distributed key-value store implementation. Can be ``"redis"``.
            dataset_name(str): Name of the dataset. Default ``""``.
            writer_buffer_size(int): Number of samples to collect before writing to the backend key-value store.
                Useful for improving the backend throughput.

        Example::

            To use a list of existing Redis servers for the "redis" backend:

            >>> from bagua.torch_api.contrib import CacheLoader
            >>>
            >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}]
            >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, dataset_name="test")
            >>>
            >>> loader.get(index, lambda x: items[x])

            To spawn new Redis servers for the "redis" backend:

            >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000)

        .. note::
            Cache loaders with the same ``dataset_name`` will reuse and overwrite each other's cache.
            Use a different ``dataset_name`` if this is not desired.

        """

        self.backend = backend
        self.dataset_name = dataset_name

        if backend == "redis":
            from .utils.redis_store import RedisStore

            self.store = RedisStore(**kwargs)
        else:
            raise ValueError('Invalid backend, only support "redis" currently')

        self.fetcher = BatchFetcher(self.store, 1, writer_buffer_size)

    def get(self, key: str, load_fn: Callable[[str], Any]):
        """
        Returns the value associated with ``key`` in the cache, using ``load_fn`` to create the entry if the key
        does not exist in the cache. ``load_fn`` is a function taking ``key`` as its argument and returning the
        corresponding value to be cached.
        """

        cache_key = "{}_{}".format(self.dataset_name, key)
        ret = self.fetcher.read(cache_key)

        if ret is None:
            ret = load_fn(key)
            # write to store
            self.fetcher.write(cache_key, ret)
        return ret

    def num_keys(self):
        """Returns the number of keys in the cache."""

        return self.store.num_keys()


class BatchFetcher:
    def __init__(self, store, read_buffer_size, writer_buffer_size):
        self.store = store
        self.read_buffer_size = max(1, read_buffer_size)
        self.writer_buffer_size = max(1, writer_buffer_size)

        self.write_map = defaultdict()
        self.write_cnt = 0
        self.read_cnt = 0

        self.last_write_tms = None

    def read(self, key):
        self.read_cnt += 1

        try:
            ret = self.store.get(key)
        except Exception:
            ret = None
        else:
            self.write_post_read()

        if ret is not None:
            return deserialize(ret)
        return ret

    def write(self, key, value):
        self.write_cnt += 1

        self.write_map[key] = serialize(value)
        if self.write_cnt % self.writer_buffer_size == 0:
            self.flush_write_map()

    def write_post_read(self):
        if self.read_cnt % 1000 == 0 and len(self.write_map) > 0:
            self.flush_write_map()

    def flush_write_map(self):
        try:
            self.store.mset(self.write_map)
        except Exception:
            pass
        else:
            self.write_map.clear()
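A note on the write path: ``BatchFetcher`` buffers serialized values in ``write_map`` and flushes them to the store with a single ``mset`` once ``writer_buffer_size`` writes have accumulated, plus an opportunistic flush after every 1000 reads. The standalone sketch below mirrors only that batching logic against a plain in-memory dict so it runs without Redis or bagua; ``DictStore`` and every name in it are illustrative stand-ins, not part of this PR.

    import pickle
    from collections import defaultdict


    class DictStore:
        """In-memory stand-in for the Redis-backed store (illustrative only)."""

        def __init__(self):
            self._data = {}

        def get(self, key):
            return self._data.get(key)

        def mset(self, mapping):
            # Batched write: one call persists every buffered entry.
            self._data.update(mapping)


    store = DictStore()
    write_map = defaultdict()
    writer_buffer_size = 3

    for i in range(7):
        # Buffer serialized values instead of writing them one by one.
        write_map[f"demo_{i}"] = pickle.dumps(i * i)
        if (i + 1) % writer_buffer_size == 0:
            store.mset(write_map)  # flush every `writer_buffer_size` writes
            write_map.clear()

    store.mset(write_map)  # flush the remainder
    write_map.clear()
    print(pickle.loads(store.get("demo_5")))  # prints 25

Batching trades a small window of unflushed writes for far fewer round trips to the backend, which is exactly the throughput benefit the ``writer_buffer_size`` docstring describes.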
62 changes: 62 additions & 0 deletions bagua/torch_api/contrib/cached_dataset.py
@@ -0,0 +1,62 @@
from torch.utils.data.dataset import Dataset
from .cache_loader import CacheLoader

__all__ = ["CachedDataset"]


class CachedDataset(Dataset):
    def __init__(
        self,
        dataset: Dataset,
        backend: str = "redis",
        dataset_name: str = "",
        writer_buffer_size: int = 20,
        **kwargs,
    ):
        """
        Cached dataset wraps a `PyTorch dataset <https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset>`_
        to cache its samples in memory, so that accessing these samples after the
        first time can be much faster. This is useful when samples need tedious preprocessing to produce, or when
        reading the dataset itself is slow, either of which can slow down the whole training process.

        Internally, the samples are indexed by a string key ``"{dataset_name}_{index}"`` and saved in a distributed key-value
        store, where ``dataset_name`` is specified when initializing the cached dataset, and ``index`` is the index
        of a specific sample (the argument of the :meth:`__getitem__` method in a PyTorch dataset).

        Args:
            dataset: PyTorch dataset to be wrapped.
            backend(str): Backend distributed key-value store implementation. Can be ``"redis"``.
            dataset_name(str): Name of the dataset. Default ``""``.
            writer_buffer_size(int): Number of samples to collect before writing to the backend key-value store.
                Useful for improving the backend throughput.

        Example::

            >>> from bagua.torch_api.contrib import CachedDataset
            >>> cached_dataset = CachedDataset(dataset, backend="redis", dataset_name="ds")
            >>> dataloader = torch.utils.data.DataLoader(cached_dataset)

        .. note::
            Cached dataset is a special case of cache loader. The ``backend`` and ``writer_buffer_size`` parameters
            of a cached dataset have the same meanings as those of a cache loader. You can provide the arguments
            for the cache loader here in ``**kwargs``. See also :class:`~bagua.torch_api.contrib.cache_loader.CacheLoader`.

        """

        self.dataset = dataset

        self.cache_loader = CacheLoader(
            backend,
            dataset_name,
            writer_buffer_size,
            **kwargs,
        )
        """
        The backend cache loader instance.
        """

    def __getitem__(self, item):
        return self.cache_loader.get(item, lambda x: self.dataset[x])

    def __len__(self):
        return len(self.dataset)
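For an end-to-end illustration, here is a usage sketch (not part of the PR: it assumes ``bagua`` is installed and a working ``redis`` backend is reachable with the default store settings, and ``SlowDataset`` is a made-up toy dataset). The first epoch fills the cache; later epochs read from it.

    import time

    import torch
    from torch.utils.data import DataLoader
    from torch.utils.data.dataset import Dataset

    from bagua.torch_api.contrib import CachedDataset


    class SlowDataset(Dataset):
        """Toy dataset whose __getitem__ simulates expensive preprocessing."""

        def __len__(self):
            return 8

        def __getitem__(self, index):
            time.sleep(0.1)  # pretend this is costly decoding/augmentation
            return torch.tensor([index])


    # writer_buffer_size=1 flushes every write, so even this tiny dataset
    # lands in the store before the second epoch starts.
    cached = CachedDataset(
        SlowDataset(), backend="redis", dataset_name="slow_demo", writer_buffer_size=1
    )
    loader = DataLoader(cached)

    for epoch in range(2):  # epoch 0 fills the cache, epoch 1 hits it
        start = time.time()
        for _ in loader:
            pass
        print(f"epoch {epoch}: {time.time() - start:.2f}s")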
1 change: 1 addition & 0 deletions bagua/torch_api/contrib/utils/__init__.py
@@ -0,0 +1 @@
__all__ = ["redis_store", "store"]