From 62325c6c607cdac6689ac14ffe981cc2d570a4a4 Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 28 Jul 2021 07:35:52 +0000 Subject: [PATCH 01/63] init commit --- bagua/torch_api/contrib/cached_dataset.py | 44 +++++++ bagua/torch_api/contrib/utils/__init__.py | 0 bagua/torch_api/contrib/utils/lmdb_store.py | 52 ++++++++ bagua/torch_api/contrib/utils/redis_store.py | 131 +++++++++++++++++++ bagua/torch_api/contrib/utils/store.py | 24 ++++ tests/contrib/test_cached_dataset.py | 23 ++++ tests/contrib/test_store.py | 54 ++++++++ 7 files changed, 328 insertions(+) create mode 100644 bagua/torch_api/contrib/cached_dataset.py create mode 100644 bagua/torch_api/contrib/utils/__init__.py create mode 100644 bagua/torch_api/contrib/utils/lmdb_store.py create mode 100644 bagua/torch_api/contrib/utils/redis_store.py create mode 100644 bagua/torch_api/contrib/utils/store.py create mode 100644 tests/contrib/test_cached_dataset.py create mode 100644 tests/contrib/test_store.py diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py new file mode 100644 index 000000000..41ca91c00 --- /dev/null +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -0,0 +1,44 @@ +from torch.utils.data.dataset import Dataset +import pyarrow as pa + + +def serialize(input): + try: + return pa.serialize(input).to_buffer() + except Exception as e: + raise RuntimeError("Serialization error!") + +def deserialize(input): + try: + return pa.deserialize(input) + except Exception as e: + raise RuntimeError("Deserialization error!") + + +class CachedDataset(Dataset): + def __init__(self, dataset: Dataset, backend: str="redis", **kwargs): + self.dataset = dataset + self.backend = backend + + if backend == "redis": + from .utils.redis_store import RedisStore + self.store = RedisStore(**kwargs) + elif backend == "lmdb": + from .utils.lmdb_store import LmdbStore + self.store = LmdbStore(**kwargs) + else: + raise ValueError("invalid backend, only support \"redis\" and \"lmdb\" at 
present") + + def __getitem__(self, item): + value = self.store.get(str(item)) + + if value is not None: + return value + + # write to store + value = self.dataset[item] + self.store.set(str(item), serialize(value)) + return value + + def __len__(self): + return len(self.dataset) diff --git a/bagua/torch_api/contrib/utils/__init__.py b/bagua/torch_api/contrib/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bagua/torch_api/contrib/utils/lmdb_store.py b/bagua/torch_api/contrib/utils/lmdb_store.py new file mode 100644 index 000000000..a604e8cdd --- /dev/null +++ b/bagua/torch_api/contrib/utils/lmdb_store.py @@ -0,0 +1,52 @@ +import lmdb +from .store import Store +from typing import List, Dict, Optional + + +class LmdbStore(Store): + def __init__(self, name, map_size: int = 1_000_000_000): + self.map_size = map_size + self.name = name + self.db = lmdb.open(self.name, map_size=self.map_size) + + def set(self, key: str, value: str): + with self.db.begin(write=True) as txn: + txn.put(key, value) + + def get(self, key: str) -> Optional[str]: + with self.db.begin(write=False) as txn: + return txn.get(key) + + def num_keys(self) -> int: + return self.db.stat()["entries"] + + def clear(self): + # TODO + raise NotImplementedError("not implemented in `LmdbStore`") + + def mset(self, mapping: Dict[str, str]): + kvpairs = list(zip(mapping.keys(), mapping.values())) + + with self.db.begin(write=True) as txn: + cursor = txn.cursor() + consumed_cnt, added_cnt = cursor.putmulti(kvpairs) + + if consumed_cnt != added_cnt: + raise RuntimeError( + "LmdbStore mset failed with: {}, failed to set {} items".format( + mapping, consumed_cnt - added_cnt + ) + ) + + def mget(self, keys: List[str]) -> List[Optional[str]]: + + with self.db.begin(write=False) as txn: + cursor = txn.cursor() + kvpairs = cursor.getmulti(keys) + + mapping = {k: v for k, v in kvpairs} + return list(map(lambda k: mapping.get(k, None), keys)) + + def status(self) -> bool: + # TODO + raise 
NotImplementedError("not implemented in `LmdbStore`") diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py new file mode 100644 index 000000000..93c1643a6 --- /dev/null +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -0,0 +1,131 @@ +import socket +import subprocess +import time +from bagua.torch_api.env import get_rank, get_local_rank, get_world_size, get_local_size +from rediscluster import RedisCluster +from redis import Redis +from typing import List, Dict, Optional +from .store import Store +import torch.distributed.distributed_c10d as c10d +import torch +import pickle +import logging +import redis + + +class RedisStore(Store): + def __init__(self, bootstrap=True, hosts: List[Dict[str, str]] = None): + if not bootstrap and (hosts is None or len(hosts) == 0): + raise ValueError("Must provide `hosts` when bootstrap is `False`") + + if bootstrap: + if hosts is not None and len(hosts) > 0: + logging.warn("Ignore input `hosts` when bootstrap is `True`") + hosts = [] + + self.cluster_mode = True + self.hosts = hosts + + if bootstrap: + self._start_redis_cluster() + + if self.cluster_mode: + self.client = RedisCluster(startup_nodes=hosts, decode_responses=True) + else: + self.client = Redis(host=self.hosts[0]["host"], port=self.hosts[0]["port"]) + + assert self.client.ping() + + def set(self, key: str, value: str): + self.client.set(key, value) + + def get(self, key: str) -> Optional[str]: + return self.client.get(key) + + def num_keys(self) -> int: + return sum(self.client.dbsize().values()) + + def clear(self): + self.client.flushdb() + + def mset(self, mapping: Dict[str, str]): + self.client.mset(mapping) + + def mget(self, keys: List[str]) -> List[Optional[str]]: + return self.client.mget(keys) + + def status(self) -> bool: + return self.client.ping() + + def _start_redis_cluster(self): + nrank = get_rank() // get_local_size() + nnodes = get_world_size() // get_local_size() + + ip, port = get_host_ip(), 
find_free_port() + if not torch.distributed.is_initialized() or nnodes == 1: + start_redis_server_cli(port, False) + self.hosts.append({"host": "127.0.0.1", "port": port}) + self.cluster_mode = False + return + + default_store = c10d._get_default_store() + + key_pattern = "redis-node{}" + if get_local_rank() == 0: + start_redis_server_cli(port, True) + content = {"host": ip, "port": port} + default_store.set(key_pattern.format(nrank), pickle.dumps(content)) + + for i in range(nnodes): + ret = default_store.get(key_pattern.format(i)) + print(ret) + self.hosts.append(ret) + + create_redis_cluster_cli(self.hosts) + + +def create_redis_cluster_cli(hosts: List[Dict[str, str]]): + cmd = ["redis-cli", "--cluster", "create"] + + for h in hosts: + cmd.append("{}:{}".format(h["host"], h["port"])) + + logging.debug(f"create redis cluster, command: {cmd}") + + subprocess.run(cmd, capture_output=True, text=True, input="yes") + time.sleep(5) + + +def start_redis_server_cli(port, cluster_mode, *args): + cmd = ["redis-server", "--daemonize yes", "--port {}".format(port)] + + if cluster_mode: + cluster_config = [ + "--cluster-enabled yes", + "--cluster-config-file nodes.conf", + "--cluster-node-timeout 5000", + "--appendonly yes", + ] + cmd.extend(cluster_config) + + cmd.extend(list(args)) + logging.debug(f"start redis server, command: {cmd}") + subprocess.run(cmd) + time.sleep(10) + + +def find_free_port(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", 0)) + sockname = sock.getsockname() + sock.close() + return sockname[1] + + +def get_host_ip(): + try: + host_name = socket.gethostname() + return socket.gethostbyname(host_name) + except: + print("Unable to get host IP") diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py new file mode 100644 index 000000000..54b529bec --- /dev/null +++ b/bagua/torch_api/contrib/utils/store.py @@ -0,0 +1,24 @@ 
+from typing import List, Dict, Optional + + +class Store: + def set(self, key: str, value: str): + pass + + def get(self, key: str) -> Optional[str]: + pass + + def num_keys(self) -> int: + pass + + def clear(self): + pass + + def mset(self, mapping: Dict[str, str]): + pass + + def mget(self, keys: List[str]) -> List[Optional[str]]: + pass + + def status(self) -> bool: + pass diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py new file mode 100644 index 000000000..6572341c7 --- /dev/null +++ b/tests/contrib/test_cached_dataset.py @@ -0,0 +1,23 @@ +if __name__ == "__main__": + n = 10 + + class TestDataset(Dataset): + def __getitem__(self, item): + if item < 10: + return (np.random.rand(5, 2), np.random.rand(1)) + raise IndexError("xxx") + + def __len__(self): + return 10 + + dataset = TestDataset() + + for i, data in enumerate(dataset): + print(i, data) + + # with CachedDataset(dataset) as ms: + ms = CachedDataset(dataset) + print(len(ms)) + for _ in range(100): + for i, data in enumerate(ms): + print(i, data) diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py new file mode 100644 index 000000000..7b292fd40 --- /dev/null +++ b/tests/contrib/test_store.py @@ -0,0 +1,54 @@ +import unittest +from bagua.torch_api.contrib.utils.redis_store import ( + RedisStore, + start_redis_server_cli, + create_redis_cluster_cli, + find_free_port, +) +from bagua.torch_api.contrib.utils.lmdb_store import LmdbStore +import logging + +logging.basicConfig(level=logging.DEBUG) + + +class TestStore(unittest.TestCase): + def check(self, store): + store.set(b"Beijing", b"China") + store.set(b"Paris", b"France") + + store.mset({b"New Delhi": b"India", b"Tokyo": b"Japan", b"Madrid": b"Spain"}) + ret = store.mget([b"Beijing", b"London", b"Tokyo"]) + self.assertEqual(ret[0], b"China") + self.assertEqual(ret[1], None) + self.assertEqual(ret[2], b"Japan") + + r1 = store.get(b"Madrid") + r2 = store.get(b"Shanghai") + self.assertEqual(r1, 
b"Spain") + self.assertEqual(r2, None) + + def test_redis_store(self): + store = RedisStore(bootstrap=True) + self.check(store) + + def test_redis_cluster_store(self): + n = 3 + hosts = [] + for i in range(n): + port = find_free_port() + start_redis_server_cli( + port, True, f"--cluster-config-file nodes{port}.conf" + ) + hosts.append({"host": "127.0.0.1", "port": port}) + + create_redis_cluster_cli(hosts=hosts) + store = RedisStore(hosts=hosts, bootstrap=False) + self.check(store) + + def test_lmdb_store(self): + store = LmdbStore(name=".test.lmdb") + self.check(store) + + +if __name__ == "__main__": + unittest.main() From 2c640bac0d302f0b3c5c1b27217b687baf88440b Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 28 Jul 2021 17:19:54 +0800 Subject: [PATCH 02/63] add --- bagua/torch_api/contrib/cached_dataset.py | 33 ++++++----- bagua/torch_api/contrib/utils/redis_store.py | 6 +- tests/contrib/test_cached_dataset.py | 60 ++++++++++++++------ 3 files changed, 63 insertions(+), 36 deletions(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 41ca91c00..965242c7b 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -1,44 +1,43 @@ from torch.utils.data.dataset import Dataset -import pyarrow as pa +import pickle def serialize(input): - try: - return pa.serialize(input).to_buffer() - except Exception as e: - raise RuntimeError("Serialization error!") + return pickle.dumps(input) + def deserialize(input): - try: - return pa.deserialize(input) - except Exception as e: - raise RuntimeError("Deserialization error!") + return pickle.loads(input) class CachedDataset(Dataset): - def __init__(self, dataset: Dataset, backend: str="redis", **kwargs): + def __init__(self, dataset: Dataset, backend: str = "redis", **kwargs): self.dataset = dataset self.backend = backend if backend == "redis": from .utils.redis_store import RedisStore + self.store = RedisStore(**kwargs) elif 
backend == "lmdb": from .utils.lmdb_store import LmdbStore + self.store = LmdbStore(**kwargs) else: - raise ValueError("invalid backend, only support \"redis\" and \"lmdb\" at present") + raise ValueError( + 'invalid backend, only support "redis" and "lmdb" at present' + ) def __getitem__(self, item): - value = self.store.get(str(item)) + ret = self.store.get(str(item).encode()) - if value is not None: - return value + if ret is not None: + return deserialize(ret) # write to store - value = self.dataset[item] - self.store.set(str(item), serialize(value)) - return value + ret = self.dataset[item] + self.store.set(str(item).encode(), serialize(ret)) + return ret def __len__(self): return len(self.dataset) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 93c1643a6..711c6571d 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -43,7 +43,11 @@ def get(self, key: str) -> Optional[str]: return self.client.get(key) def num_keys(self) -> int: - return sum(self.client.dbsize().values()) + return ( + sum(self.client.dbsize().values()) + if self.cluster_mode + else self.client.dbsize() + ) def clear(self): self.client.flushdb() diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 6572341c7..7cdd221f4 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -1,23 +1,47 @@ -if __name__ == "__main__": - n = 10 +from bagua.torch_api.contrib.cached_dataset import CachedDataset +from torch.utils.data.dataset import Dataset +import numpy as np +import logging +import unittest + +logging.basicConfig(level=logging.DEBUG) + + +class TestDataset(Dataset): + def __init__(self, size): + self.size = size + self.dataset = [(np.random.rand(5, 2), np.random.rand(1)) for _ in range(size)] + + def __getitem__(self, item): + return self.dataset[item] - class TestDataset(Dataset): - def 
__getitem__(self, item): - if item < 10: - return (np.random.rand(5, 2), np.random.rand(1)) - raise IndexError("xxx") + def __len__(self): + return self.size - def __len__(self): - return 10 - dataset = TestDataset() +class TestCachedDataset(unittest.TestCase): + def check_dataset(self, dataset, cached_dataset): + for step, data in enumerate(cached_dataset): + pass - for i, data in enumerate(dataset): - print(i, data) + self.assertEqual(cached_dataset.store.num_keys(), 10) + for i in range(10): + print(i, dataset[i][0], cached_dataset[i][0]) + self.assertTrue((dataset[i][0] == cached_dataset[i][0]).all()) + self.assertTrue((dataset[i][1] == cached_dataset[i][1]).all()) - # with CachedDataset(dataset) as ms: - ms = CachedDataset(dataset) - print(len(ms)) - for _ in range(100): - for i, data in enumerate(ms): - print(i, data) + def test_lmdb(self): + np.random.seed(0) + dataset = TestDataset(10) + cached_dataset = CachedDataset(dataset, backend="lmdb", name=".test.lmdb") + self.check_dataset(dataset, cached_dataset) + + def test_redis(self): + np.random.seed(0) + dataset = TestDataset(10) + cached_dataset = CachedDataset(dataset, backend="redis") + self.check_dataset(dataset, cached_dataset) + + +if __name__ == "__main__": + unittest.main() From 9dda8cc51e28cb8dad6ec8334de64dc5b7a947ce Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 28 Jul 2021 19:54:38 +0800 Subject: [PATCH 03/63] update --- bagua/torch_api/contrib/cached_dataset.py | 9 +++-- bagua/torch_api/contrib/utils/lmdb_store.py | 35 ++++++++++++-------- bagua/torch_api/contrib/utils/redis_store.py | 23 ++++++++++--- bagua/torch_api/contrib/utils/store.py | 5 ++- tests/contrib/test_cached_dataset.py | 10 ++++-- tests/contrib/test_store.py | 21 +++++++++--- 6 files changed, 73 insertions(+), 30 deletions(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 965242c7b..0821542f4 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ 
b/bagua/torch_api/contrib/cached_dataset.py @@ -11,18 +11,18 @@ def deserialize(input): class CachedDataset(Dataset): - def __init__(self, dataset: Dataset, backend: str = "redis", **kwargs): + def __init__(self, dataset: Dataset, backend: str = "redis", overwrite=True, capacity: int = 10_000_000_000, **kwargs): self.dataset = dataset self.backend = backend if backend == "redis": from .utils.redis_store import RedisStore - self.store = RedisStore(**kwargs) + self.store = RedisStore(overwrite=overwrite, capacity=capacity, **kwargs) elif backend == "lmdb": from .utils.lmdb_store import LmdbStore - self.store = LmdbStore(**kwargs) + self.store = LmdbStore(overwrite=overwrite, capacity=capacity, **kwargs) else: raise ValueError( 'invalid backend, only support "redis" and "lmdb" at present' @@ -41,3 +41,6 @@ def __getitem__(self, item): def __len__(self): return len(self.dataset) + + def cleanup(self): + self.store.shutdown() diff --git a/bagua/torch_api/contrib/utils/lmdb_store.py b/bagua/torch_api/contrib/utils/lmdb_store.py index a604e8cdd..a23e0dbce 100644 --- a/bagua/torch_api/contrib/utils/lmdb_store.py +++ b/bagua/torch_api/contrib/utils/lmdb_store.py @@ -4,30 +4,37 @@ class LmdbStore(Store): - def __init__(self, name, map_size: int = 1_000_000_000): - self.map_size = map_size + def __init__(self, name, capacity: int = 1_000_000_000, overwrite=True): self.name = name - self.db = lmdb.open(self.name, map_size=self.map_size) + self.capacity = capacity + self.env = lmdb.open(self.name, map_size=self.capacity) + + if overwrite: + self.clear() def set(self, key: str, value: str): - with self.db.begin(write=True) as txn: + with self.env.begin(write=True) as txn: txn.put(key, value) def get(self, key: str) -> Optional[str]: - with self.db.begin(write=False) as txn: + with self.env.begin(write=False) as txn: return txn.get(key) def num_keys(self) -> int: - return self.db.stat()["entries"] + return self.env.stat()["entries"] + + def clear(self) -> bool: + db = 
self.env.open_db() - def clear(self): - # TODO - raise NotImplementedError("not implemented in `LmdbStore`") + with self.env.begin(write=True) as txn: + txn.drop(db) + + return self.num_keys() def mset(self, mapping: Dict[str, str]): kvpairs = list(zip(mapping.keys(), mapping.values())) - with self.db.begin(write=True) as txn: + with self.env.begin(write=True) as txn: cursor = txn.cursor() consumed_cnt, added_cnt = cursor.putmulti(kvpairs) @@ -40,7 +47,7 @@ def mset(self, mapping: Dict[str, str]): def mget(self, keys: List[str]) -> List[Optional[str]]: - with self.db.begin(write=False) as txn: + with self.env.begin(write=False) as txn: cursor = txn.cursor() kvpairs = cursor.getmulti(keys) @@ -48,5 +55,7 @@ def mget(self, keys: List[str]) -> List[Optional[str]]: return list(map(lambda k: mapping.get(k, None), keys)) def status(self) -> bool: - # TODO - raise NotImplementedError("not implemented in `LmdbStore`") + return self.env.stat() + + def shutdown(self): + self.env.close() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 711c6571d..a0654230a 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -14,7 +14,13 @@ class RedisStore(Store): - def __init__(self, bootstrap=True, hosts: List[Dict[str, str]] = None): + def __init__( + self, + capacity: int = 1_000_000_000, + bootstrap=True, + hosts: List[Dict[str, str]] = None, + overwrite=True, + ): if not bootstrap and (hosts is None or len(hosts) == 0): raise ValueError("Must provide `hosts` when bootstrap is `False`") @@ -23,6 +29,7 @@ def __init__(self, bootstrap=True, hosts: List[Dict[str, str]] = None): logging.warn("Ignore input `hosts` when bootstrap is `True`") hosts = [] + self.capacity = capacity self.cluster_mode = True self.hosts = hosts @@ -35,6 +42,8 @@ def __init__(self, bootstrap=True, hosts: List[Dict[str, str]] = None): self.client = Redis(host=self.hosts[0]["host"], 
port=self.hosts[0]["port"]) assert self.client.ping() + if overwrite: + self.clear() def set(self, key: str, value: str): self.client.set(key, value) @@ -49,7 +58,7 @@ def num_keys(self) -> int: else self.client.dbsize() ) - def clear(self): + def clear(self) -> bool: self.client.flushdb() def mset(self, mapping: Dict[str, str]): @@ -61,13 +70,17 @@ def mget(self, keys: List[str]) -> List[Optional[str]]: def status(self) -> bool: return self.client.ping() + def shutdown(self): + self.client.shutdown() + def _start_redis_cluster(self): nrank = get_rank() // get_local_size() nnodes = get_world_size() // get_local_size() + capacity = (self.capacity + nnodes - 1) // nnodes ip, port = get_host_ip(), find_free_port() if not torch.distributed.is_initialized() or nnodes == 1: - start_redis_server_cli(port, False) + start_redis_server_cli(port, False, "--maxmemory {}".format(capacity)) self.hosts.append({"host": "127.0.0.1", "port": port}) self.cluster_mode = False return @@ -76,7 +89,7 @@ def _start_redis_cluster(self): key_pattern = "redis-node{}" if get_local_rank() == 0: - start_redis_server_cli(port, True) + start_redis_server_cli(port, True, "--maxmemory {}".format(capacity)) content = {"host": ip, "port": port} default_store.set(key_pattern.format(nrank), pickle.dumps(content)) @@ -132,4 +145,4 @@ def get_host_ip(): host_name = socket.gethostname() return socket.gethostbyname(host_name) except: - print("Unable to get host IP") + raise RuntimeError("Unable to get host IP") diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 54b529bec..156fb339f 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -11,7 +11,7 @@ def get(self, key: str) -> Optional[str]: def num_keys(self) -> int: pass - def clear(self): + def cleanup(self) -> bool: pass def mset(self, mapping: Dict[str, str]): @@ -22,3 +22,6 @@ def mget(self, keys: List[str]) -> List[Optional[str]]: def status(self) -> bool: 
pass + + def shutdown(self): + pass diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 7cdd221f4..0d8e301de 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -26,21 +26,25 @@ def check_dataset(self, dataset, cached_dataset): self.assertEqual(cached_dataset.store.num_keys(), 10) for i in range(10): - print(i, dataset[i][0], cached_dataset[i][0]) self.assertTrue((dataset[i][0] == cached_dataset[i][0]).all()) self.assertTrue((dataset[i][1] == cached_dataset[i][1]).all()) def test_lmdb(self): np.random.seed(0) dataset = TestDataset(10) - cached_dataset = CachedDataset(dataset, backend="lmdb", name=".test.lmdb") + cached_dataset = CachedDataset( + dataset, backend="lmdb", name=".test.lmdb", overwrite=True + ) self.check_dataset(dataset, cached_dataset) + cached_dataset.cleanup() + def test_redis(self): np.random.seed(0) dataset = TestDataset(10) - cached_dataset = CachedDataset(dataset, backend="redis") + cached_dataset = CachedDataset(dataset, backend="redis", overwrite=True) self.check_dataset(dataset, cached_dataset) + cached_dataset.cleanup() if __name__ == "__main__": diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index 7b292fd40..e3b015141 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -27,6 +27,21 @@ def check(self, store): self.assertEqual(r1, b"Spain") self.assertEqual(r2, None) + cnt = store.num_keys() + self.assertEqual(cnt, 5) + + store.clear() + self.assertEqual(store.num_keys(), 0) + + self.assertTrue(store.status()) + + # shut down resources at the end + store.shutdown() + + def test_lmdb_store(self): + store = LmdbStore(name=".test.lmdb", overwrite=True) + self.check(store) + def test_redis_store(self): store = RedisStore(bootstrap=True) self.check(store) @@ -42,11 +57,7 @@ def test_redis_cluster_store(self): hosts.append({"host": "127.0.0.1", "port": port}) create_redis_cluster_cli(hosts=hosts) - store = 
RedisStore(hosts=hosts, bootstrap=False) - self.check(store) - - def test_lmdb_store(self): - store = LmdbStore(name=".test.lmdb") + store = RedisStore(hosts=hosts, bootstrap=False, overwrite=True) self.check(store) From d41667ebe83f304a100060e88b251e10e6dae831 Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 28 Jul 2021 23:35:38 +0800 Subject: [PATCH 04/63] update redis config --- bagua/torch_api/contrib/cached_dataset.py | 17 +++++-- bagua/torch_api/contrib/utils/lmdb_store.py | 6 +-- bagua/torch_api/contrib/utils/redis_store.py | 30 ++++++++---- tests/contrib/test_store.py | 50 +++++++++++++++++++- 4 files changed, 85 insertions(+), 18 deletions(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 0821542f4..ab777734b 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -11,18 +11,29 @@ def deserialize(input): class CachedDataset(Dataset): - def __init__(self, dataset: Dataset, backend: str = "redis", overwrite=True, capacity: int = 10_000_000_000, **kwargs): + def __init__( + self, + dataset: Dataset, + backend: str = "redis", + overwrite=True, + capacity_per_node: int = 10_000_000_000, + **kwargs, + ): self.dataset = dataset self.backend = backend if backend == "redis": from .utils.redis_store import RedisStore - self.store = RedisStore(overwrite=overwrite, capacity=capacity, **kwargs) + self.store = RedisStore( + overwrite=overwrite, capacity_per_node=capacity_per_node, **kwargs + ) elif backend == "lmdb": from .utils.lmdb_store import LmdbStore - self.store = LmdbStore(overwrite=overwrite, capacity=capacity, **kwargs) + self.store = LmdbStore( + overwrite=overwrite, capacity_per_node=capacity_per_node, **kwargs + ) else: raise ValueError( 'invalid backend, only support "redis" and "lmdb" at present' diff --git a/bagua/torch_api/contrib/utils/lmdb_store.py b/bagua/torch_api/contrib/utils/lmdb_store.py index a23e0dbce..cfc8bb483 100644 --- 
a/bagua/torch_api/contrib/utils/lmdb_store.py +++ b/bagua/torch_api/contrib/utils/lmdb_store.py @@ -4,10 +4,10 @@ class LmdbStore(Store): - def __init__(self, name, capacity: int = 1_000_000_000, overwrite=True): + def __init__(self, name, capacity_per_node: int = 1_000_000_000, overwrite=True): self.name = name - self.capacity = capacity - self.env = lmdb.open(self.name, map_size=self.capacity) + self.capacity_per_node = capacity_per_node + self.env = lmdb.open(self.name, map_size=self.capacity_per_node) if overwrite: self.clear() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index a0654230a..34db48a7a 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -16,7 +16,7 @@ class RedisStore(Store): def __init__( self, - capacity: int = 1_000_000_000, + capacity_per_node: int = 1_000_000_000, bootstrap=True, hosts: List[Dict[str, str]] = None, overwrite=True, @@ -29,7 +29,7 @@ def __init__( logging.warn("Ignore input `hosts` when bootstrap is `True`") hosts = [] - self.capacity = capacity + self.capacity_per_node = capacity_per_node self.cluster_mode = True self.hosts = hosts @@ -71,31 +71,34 @@ def status(self) -> bool: return self.client.ping() def shutdown(self): - self.client.shutdown() + if hasattr(self, "_client_on_host") and self._client_on_host is not None: + # shutdown redis server bootstrapped locally + self._client_on_host.shutdown(nosave=True) + self.client.close() def _start_redis_cluster(self): nrank = get_rank() // get_local_size() nnodes = get_world_size() // get_local_size() - capacity = (self.capacity + nnodes - 1) // nnodes ip, port = get_host_ip(), find_free_port() if not torch.distributed.is_initialized() or nnodes == 1: - start_redis_server_cli(port, False, "--maxmemory {}".format(capacity)) + start_redis_server_cli(port, False, self.capacity_per_node) self.hosts.append({"host": "127.0.0.1", "port": port}) self.cluster_mode = False 
+ self._client_on_host = Redis(port=port) return default_store = c10d._get_default_store() key_pattern = "redis-node{}" if get_local_rank() == 0: - start_redis_server_cli(port, True, "--maxmemory {}".format(capacity)) + start_redis_server_cli(port, True, self.capacity_per_node) content = {"host": ip, "port": port} default_store.set(key_pattern.format(nrank), pickle.dumps(content)) + self._client_on_host = Redis(port=port) for i in range(nnodes): ret = default_store.get(key_pattern.format(i)) - print(ret) self.hosts.append(ret) create_redis_cluster_cli(self.hosts) @@ -113,15 +116,22 @@ def create_redis_cluster_cli(hosts: List[Dict[str, str]]): time.sleep(5) -def start_redis_server_cli(port, cluster_mode, *args): - cmd = ["redis-server", "--daemonize yes", "--port {}".format(port)] +def start_redis_server_cli(port, cluster_mode, capacity, *args): + cmd = [ + "redis-server", + "--daemonize yes", + "--port {}".format(port), + "--maxmemory {}".format(capacity), + "--maxmemory-policy allkeys-random", # use random eviction by default + "--appendonly no", # disable persistence by default + '--save ""', + ] if cluster_mode: cluster_config = [ "--cluster-enabled yes", "--cluster-config-file nodes.conf", "--cluster-node-timeout 5000", - "--appendonly yes", ] cmd.extend(cluster_config) diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index e3b015141..9ef94483a 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -6,6 +6,8 @@ find_free_port, ) from bagua.torch_api.contrib.utils.lmdb_store import LmdbStore +import redis +import multiprocessing as mp import logging logging.basicConfig(level=logging.DEBUG) @@ -46,20 +48,64 @@ def test_redis_store(self): store = RedisStore(bootstrap=True) self.check(store) + +class TestClusterStore(unittest.TestCase): + def check(self, store): + store.set("a", 1) + + store.mset({"b": 2, "c": 3}) + ret = store.mget(["b", "d"]) + self.assertEqual(ret[0], str(2)) + self.assertEqual(ret[1], None) + + r1 = 
store.get("a") + r2 = store.get("d") + self.assertEqual(r1, str(1)) + self.assertEqual(r2, None) + + cnt = store.num_keys() + self.assertEqual(cnt, 3) + + store.clear() + self.assertEqual(store.num_keys(), 0) + + self.assertTrue(store.status()) + + # try to shut down resources + store.shutdown() + + self.assertTrue(store.status()) + def test_redis_cluster_store(self): n = 3 hosts = [] + ports = [] + processes = [] for i in range(n): port = find_free_port() - start_redis_server_cli( - port, True, f"--cluster-config-file nodes{port}.conf" + p = mp.Process( + target=start_redis_server_cli, + args=(port, True, 10000000, f"--cluster-config-file nodes{port}.conf"), ) + p.start() + + ports.append(port) + processes.append(p) hosts.append({"host": "127.0.0.1", "port": port}) + for p in processes: + p.join() + create_redis_cluster_cli(hosts=hosts) store = RedisStore(hosts=hosts, bootstrap=False, overwrite=True) self.check(store) + # Now shut down servers safely + for port in ports: + client = redis.Redis(port=port) + client.shutdown(nosave=True) + client.close() + if __name__ == "__main__": unittest.main() From 5d130b735a2f936945020d4d586ada4a7a42e4c9 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 29 Jul 2021 15:33:19 +0800 Subject: [PATCH 05/63] . 
--- bagua/torch_api/contrib/__init__.py | 1 + bagua/torch_api/contrib/cached_dataset.py | 18 +-- bagua/torch_api/contrib/utils/lmdb_store.py | 6 +- bagua/torch_api/contrib/utils/redis_store.py | 122 +++++++++++-------- tests/contrib/test_cached_dataset.py | 8 +- tests/contrib/test_store.py | 46 ++++--- 6 files changed, 114 insertions(+), 87 deletions(-) diff --git a/bagua/torch_api/contrib/__init__.py b/bagua/torch_api/contrib/__init__.py index 7594c089c..96786a381 100644 --- a/bagua/torch_api/contrib/__init__.py +++ b/bagua/torch_api/contrib/__init__.py @@ -3,3 +3,4 @@ LoadBalancingDistributedSampler, LoadBalancingDistributedBatchSampler, ) +from .cached_dataset import CachedDataset diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index ab777734b..aa3fe2e61 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -15,39 +15,39 @@ def __init__( self, dataset: Dataset, backend: str = "redis", - overwrite=True, capacity_per_node: int = 10_000_000_000, + key_prefix: str = "", **kwargs, ): + """ """ self.dataset = dataset self.backend = backend + self.key_prefix = key_prefix if backend == "redis": from .utils.redis_store import RedisStore - self.store = RedisStore( - overwrite=overwrite, capacity_per_node=capacity_per_node, **kwargs - ) + self.store = RedisStore(capacity_per_node=capacity_per_node, **kwargs) elif backend == "lmdb": from .utils.lmdb_store import LmdbStore - self.store = LmdbStore( - overwrite=overwrite, capacity_per_node=capacity_per_node, **kwargs - ) + self.store = LmdbStore(capacity_per_node=capacity_per_node, **kwargs) else: raise ValueError( 'invalid backend, only support "redis" and "lmdb" at present' ) def __getitem__(self, item): - ret = self.store.get(str(item).encode()) + key = "{}{}".format(self.key_prefix, item).encode() + + ret = self.store.get(key) if ret is not None: return deserialize(ret) # write to store ret = self.dataset[item] - 
self.store.set(str(item).encode(), serialize(ret)) + self.store.set(key, serialize(ret)) return ret def __len__(self): diff --git a/bagua/torch_api/contrib/utils/lmdb_store.py b/bagua/torch_api/contrib/utils/lmdb_store.py index cfc8bb483..735d8fcd2 100644 --- a/bagua/torch_api/contrib/utils/lmdb_store.py +++ b/bagua/torch_api/contrib/utils/lmdb_store.py @@ -4,10 +4,10 @@ class LmdbStore(Store): - def __init__(self, name, capacity_per_node: int = 1_000_000_000, overwrite=True): - self.name = name + def __init__(self, path, capacity_per_node: int = 1_000_000_000, overwrite=False): + self.path = path self.capacity_per_node = capacity_per_node - self.env = lmdb.open(self.name, map_size=self.capacity_per_node) + self.env = lmdb.open(self.path, map_size=self.capacity_per_node) if overwrite: self.clear() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 34db48a7a..263939429 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -2,48 +2,52 @@ import subprocess import time from bagua.torch_api.env import get_rank, get_local_rank, get_world_size, get_local_size -from rediscluster import RedisCluster + +# from rediscluster import RedisCluster from redis import Redis from typing import List, Dict, Optional from .store import Store import torch.distributed.distributed_c10d as c10d import torch -import pickle +import json import logging -import redis + +_host_ip = None class RedisStore(Store): def __init__( self, - capacity_per_node: int = 1_000_000_000, - bootstrap=True, hosts: List[Dict[str, str]] = None, - overwrite=True, + cluster_mode: bool = False, + capacity_per_node: int = 1_000_000_000, ): - if not bootstrap and (hosts is None or len(hosts) == 0): - raise ValueError("Must provide `hosts` when bootstrap is `False`") + """ """ - if bootstrap: - if hosts is not None and len(hosts) > 0: - logging.warn("Ignore input `hosts` when bootstrap is `True`") - hosts 
= [] + self.hosts = [] + if hosts is None: + logging.info("Ready to bootstrap redis server locally") + self.bootstrap = True + else: + logging.info("Ready to connect redis servers: {}".format(hosts)) + self.bootstrap = False + self.hosts.extends(hosts) + self.cluster_mode = cluster_mode self.capacity_per_node = capacity_per_node - self.cluster_mode = True - self.hosts = hosts - if bootstrap: - self._start_redis_cluster() + if self.bootstrap: + self._bootstrap_redis_server() if self.cluster_mode: - self.client = RedisCluster(startup_nodes=hosts, decode_responses=True) + raise ValueError("RedisStore does not support cluster mode at present") + # self.client = RedisCluster(startup_nodes=self.hosts, decode_responses=True) else: - self.client = Redis(host=self.hosts[0]["host"], port=self.hosts[0]["port"]) + self.client = create_redis_client( + host=self.hosts[0]["host"], port=self.hosts[0]["port"] + ) assert self.client.ping() - if overwrite: - self.clear() def set(self, key: str, value: str): self.client.set(key, value) @@ -52,11 +56,7 @@ def get(self, key: str) -> Optional[str]: return self.client.get(key) def num_keys(self) -> int: - return ( - sum(self.client.dbsize().values()) - if self.cluster_mode - else self.client.dbsize() - ) + return self.client.dbsize() def clear(self) -> bool: self.client.flushdb() @@ -71,37 +71,37 @@ def status(self) -> bool: return self.client.ping() def shutdown(self): - if hasattr(self, "_client_on_host") and self._client_on_host is not None: - # shutdown redis server bootstrapped locally - self._client_on_host.shutdown(nosave=True) - self.client.close() - - def _start_redis_cluster(self): - nrank = get_rank() // get_local_size() - nnodes = get_world_size() // get_local_size() + if self.bootstrap: + self.client.shutdown(nosave=True) + def _bootstrap_redis_server(self): ip, port = get_host_ip(), find_free_port() - if not torch.distributed.is_initialized() or nnodes == 1: - start_redis_server_cli(port, False, self.capacity_per_node) - 
self.hosts.append({"host": "127.0.0.1", "port": port}) - self.cluster_mode = False - self._client_on_host = Redis(port=port) - return + hostinfo = {"host": ip, "port": port} + if get_local_rank() == 0: + start_redis_server_cli(port, self.cluster_mode, self.capacity_per_node) + else: + wait_for_start_redis_server_cli() - default_store = c10d._get_default_store() + if get_world_size() > 1: + nrank = get_rank() // get_local_size() + nnodes = get_world_size() // get_local_size() + default_store = c10d._get_default_store() + key_pattern = "redis-node{}" - key_pattern = "redis-node{}" - if get_local_rank() == 0: - start_redis_server_cli(port, True, self.capacity_per_node) - content = {"host": ip, "port": port} - default_store.set(key_pattern.format(nrank), pickle.dumps(content)) - self._client_on_host = Redis(port=port) + if get_local_rank() == 0: + default_store.set(key_pattern.format(nrank), json.dumps(hostinfo)) - for i in range(nnodes): - ret = default_store.get(key_pattern.format(i)) - self.hosts.append(ret) + for i in range(nnodes): + ret = json.loads(default_store.get(key_pattern.format(i))) + self.hosts.append(ret) + else: + self.hosts.append(hostinfo) - create_redis_cluster_cli(self.hosts) + if not self.cluster_mode: + return + + # create_redis_cluster_cli(self.hosts) + # wait_for_create_redis_cluster_cli() def create_redis_cluster_cli(hosts: List[Dict[str, str]]): @@ -116,6 +116,10 @@ def create_redis_cluster_cli(hosts: List[Dict[str, str]]): time.sleep(5) +def wait_for_create_redis_cluster_cli(): + time.sleep(5) + + def start_redis_server_cli(port, cluster_mode, capacity, *args): cmd = [ "redis-server", @@ -141,6 +145,14 @@ def start_redis_server_cli(port, cluster_mode, capacity, *args): time.sleep(10) +def wait_for_start_redis_server_cli(): + time.sleep(10) + + +def create_redis_client(host, port): + return Redis(port=port) if host == get_host_ip() else Redis(host=host, port=port) + + def find_free_port(): sock = socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -151,8 +163,10 @@ def find_free_port(): def get_host_ip(): - try: + global _host_ip + + if _host_ip is None: host_name = socket.gethostname() - return socket.gethostbyname(host_name) - except: - raise RuntimeError("Unable to get host IP") + _host_ip = socket.gethostbyname(host_name) + + return _host_ip diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 0d8e301de..35dca492a 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -21,7 +21,7 @@ def __len__(self): class TestCachedDataset(unittest.TestCase): def check_dataset(self, dataset, cached_dataset): - for step, data in enumerate(cached_dataset): + for _, _ in enumerate(cached_dataset): pass self.assertEqual(cached_dataset.store.num_keys(), 10) @@ -33,7 +33,7 @@ def test_lmdb(self): np.random.seed(0) dataset = TestDataset(10) cached_dataset = CachedDataset( - dataset, backend="lmdb", name=".test.lmdb", overwrite=True + dataset, backend="lmdb", path=".lmdb", overwrite=True ) self.check_dataset(dataset, cached_dataset) @@ -42,7 +42,9 @@ def test_lmdb(self): def test_redis(self): np.random.seed(0) dataset = TestDataset(10) - cached_dataset = CachedDataset(dataset, backend="redis", overwrite=True) + cached_dataset = CachedDataset( + dataset, backend="redis", hosts=None, cluster_mode=False + ) self.check_dataset(dataset, cached_dataset) cached_dataset.cleanup() diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index 9ef94483a..de16973c0 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -9,11 +9,13 @@ import redis import multiprocessing as mp import logging +import numpy as np +import pickle logging.basicConfig(level=logging.DEBUG) -class TestStore(unittest.TestCase): +class TestLmdbStore(unittest.TestCase): def check(self, store): store.set(b"Beijing", b"China") store.set(b"Paris", b"France") @@ -41,26 
+43,28 @@ def check(self, store): store.shutdown() def test_lmdb_store(self): - store = LmdbStore(name=".test.lmdb", overwrite=True) - self.check(store) - - def test_redis_store(self): - store = RedisStore(bootstrap=True) + store = LmdbStore(path=".lmdb", capacity_per_node=10000000, overwrite=True) self.check(store) -class TestClusterStore(unittest.TestCase): +class TestRedisStore(unittest.TestCase): def check(self, store): - store.set("a", 1) - - store.mset({"b": 2, "c": 3}) - ret = store.mget(["b", "d"]) - self.assertEqual(ret[0], str(2)) + self.generated_data = [np.random.rand(10) for _ in range(5)] + store.set("1", pickle.dumps(self.generated_data[1])) + + store.mset( + { + "2": pickle.dumps(self.generated_data[2]), + "3": pickle.dumps(self.generated_data[3]), + } + ) + ret = store.mget(["2", "4"]) + self.assertTrue((pickle.loads(ret[0]) == self.generated_data[2]).all()) self.assertEqual(ret[1], None) - r1 = store.get("a") - r2 = store.get("d") - self.assertEqual(r1, str(1)) + r1 = store.get("1") + r2 = store.get("4") + self.assertTrue((pickle.loads(r1) == self.generated_data[1]).all()) self.assertEqual(r2, None) cnt = store.num_keys() @@ -71,12 +75,15 @@ def check(self, store): self.assertTrue(store.status()) + def test_redis_store(self): + store = RedisStore(hosts=None, cluster_mode=False, capacity_per_node=10000000) + self.check(store) + # try to shut down resources store.shutdown() - self.assertTrue(store.status()) - def test_redis_cluster_store(self): + return n = 3 hosts = [] ports = [] @@ -97,9 +104,12 @@ def test_redis_cluster_store(self): p.join() create_redis_cluster_cli(hosts=hosts) - store = RedisStore(hosts=hosts, bootstrap=False, overwrite=True) + + store = RedisStore(hosts=hosts, cluster_mode=True, capacity_per_node=10000000) self.check(store) + self.assertTrue(store.status()) + # Now shut down servers safely for port in ports: client = redis.Redis(port=port) From b6f36d2f1ff0518175847e669e2d08dc8badbbc9 Mon Sep 17 00:00:00 2001 From: ritaw Date: 
Thu, 29 Jul 2021 17:58:59 +0800 Subject: [PATCH 06/63] fix --- bagua/torch_api/contrib/utils/redis_store.py | 27 ++++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 263939429..5daba4499 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -39,13 +39,7 @@ def __init__( if self.bootstrap: self._bootstrap_redis_server() - if self.cluster_mode: - raise ValueError("RedisStore does not support cluster mode at present") - # self.client = RedisCluster(startup_nodes=self.hosts, decode_responses=True) - else: - self.client = create_redis_client( - host=self.hosts[0]["host"], port=self.hosts[0]["port"] - ) + self.client = self.create_redis_client(self.hosts) assert self.client.ping() @@ -103,6 +97,21 @@ def _bootstrap_redis_server(self): # create_redis_cluster_cli(self.hosts) # wait_for_create_redis_cluster_cli() + def create_redis_client(self, hosts): + if self.cluster_mode: + raise ValueError("RedisStore does not support cluster mode at present") + # self.client = RedisCluster(startup_nodes=self.hosts, decode_responses=True) + else: + nrank = get_rank() // get_local_size() + hostinfo = hosts[nrank % len(self.hosts)] + + logging.debug(f"{get_host_ip()} connect to redis server: {hostinfo}") + return ( + Redis(port=hostinfo["port"]) + if hostinfo["host"] == get_host_ip() + else Redis(host=host["info"], port=hostinfo["port"]) + ) + def create_redis_cluster_cli(hosts: List[Dict[str, str]]): cmd = ["redis-cli", "--cluster", "create"] @@ -149,10 +158,6 @@ def wait_for_start_redis_server_cli(): time.sleep(10) -def create_redis_client(host, port): - return Redis(port=port) if host == get_host_ip() else Redis(host=host, port=port) - - def find_free_port(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) From 
ec998cc23c3e309a9a1b9922372ae8edeae3d20a Mon Sep 17 00:00:00 2001 From: ritaw Date: Sat, 31 Jul 2021 19:51:15 +0800 Subject: [PATCH 07/63] batch writes --- bagua/torch_api/contrib/cached_dataset.py | 54 +++++++++++++++++--- bagua/torch_api/contrib/utils/redis_store.py | 2 +- tests/contrib/test_cached_dataset.py | 15 +++--- 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index aa3fe2e61..929936925 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -1,5 +1,6 @@ from torch.utils.data.dataset import Dataset import pickle +from collections import defaultdict def serialize(input): @@ -15,8 +16,10 @@ def __init__( self, dataset: Dataset, backend: str = "redis", - capacity_per_node: int = 10_000_000_000, + capacity_per_node: int = 100_000_000_000, key_prefix: str = "", + batch_reads: int = 1, + batch_writes: int = 50, **kwargs, ): """ """ @@ -37,17 +40,17 @@ def __init__( 'invalid backend, only support "redis" and "lmdb" at present' ) + self.fetcher = BatchFetcher(self.store, batch_reads, batch_writes) + def __getitem__(self, item): key = "{}{}".format(self.key_prefix, item).encode() - ret = self.store.get(key) - - if ret is not None: - return deserialize(ret) + ret = self.fetcher.read(key) - # write to store - ret = self.dataset[item] - self.store.set(key, serialize(ret)) + if ret == None: + ret = self.dataset[item] + # write to store + self.fetcher.write(key, ret) return ret def __len__(self): @@ -55,3 +58,38 @@ def __len__(self): def cleanup(self): self.store.shutdown() + + +class BatchFetcher: + def __init__(self, store, batch_reads=1, batch_writes=50): + self.store = store + self.batch_reads = batch_reads + self.batch_writes = batch_writes + + self.write_map = defaultdict() + self.write_cnt = 0 + self.read_cnt = 0 + + self.last_write_tms = None + + def read(self, key): + self.read_cnt += 1 + + ret = 
self.store.get(key) + self.write_post_read() + if ret is not None: + return deserialize(ret) + return ret + + def write(self, key, value): + self.write_cnt += 1 + + self.write_map[key] = serialize(value) + if self.write_cnt % self.batch_writes == 0: + self.store.mset(self.write_map) + self.write_map.clear() + + def write_post_read(self): + if self.read_cnt % 1000 == 0 and len(self.write_map) > 0: + self.store.mset(self.write_map) + self.write_map.clear() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 5daba4499..82c2e9301 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -20,7 +20,7 @@ def __init__( self, hosts: List[Dict[str, str]] = None, cluster_mode: bool = False, - capacity_per_node: int = 1_000_000_000, + capacity_per_node: int = 100_000_000_000, ): """ """ diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 35dca492a..6c265c6ca 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -21,17 +21,17 @@ def __len__(self): class TestCachedDataset(unittest.TestCase): def check_dataset(self, dataset, cached_dataset): - for _, _ in enumerate(cached_dataset): - pass + for _ in range(10): + for _, _ in enumerate(cached_dataset): + pass - self.assertEqual(cached_dataset.store.num_keys(), 10) - for i in range(10): + self.assertEqual(cached_dataset.store.num_keys(), len(dataset)) + for i in range(len(dataset)): self.assertTrue((dataset[i][0] == cached_dataset[i][0]).all()) self.assertTrue((dataset[i][1] == cached_dataset[i][1]).all()) def test_lmdb(self): - np.random.seed(0) - dataset = TestDataset(10) + dataset = TestDataset(102) cached_dataset = CachedDataset( dataset, backend="lmdb", path=".lmdb", overwrite=True ) @@ -40,8 +40,7 @@ def test_lmdb(self): cached_dataset.cleanup() def test_redis(self): - np.random.seed(0) - dataset = TestDataset(10) + dataset = 
TestDataset(102) cached_dataset = CachedDataset( dataset, backend="redis", hosts=None, cluster_mode=False ) From 6c6ab3dd025c4ea7a587172571cc0dfe4705c0af Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 4 Aug 2021 10:06:03 +0000 Subject: [PATCH 08/63] add cluster --- bagua/torch_api/contrib/__init__.py | 3 +- bagua/torch_api/contrib/cache_dataset.py | 33 ++++ .../{cached_dataset.py => cache_loader.py} | 34 ++-- bagua/torch_api/contrib/utils/hash_func.py | 50 ++++++ bagua/torch_api/contrib/utils/lmdb_store.py | 61 -------- bagua/torch_api/contrib/utils/redis_store.py | 147 ++++++++---------- bagua/torch_api/contrib/utils/store.py | 80 +++++++++- tests/contrib/test_cached_dataset.py | 28 ++-- tests/contrib/test_store.py | 64 ++------ 9 files changed, 268 insertions(+), 232 deletions(-) create mode 100644 bagua/torch_api/contrib/cache_dataset.py rename bagua/torch_api/contrib/{cached_dataset.py => cache_loader.py} (70%) create mode 100644 bagua/torch_api/contrib/utils/hash_func.py delete mode 100644 bagua/torch_api/contrib/utils/lmdb_store.py diff --git a/bagua/torch_api/contrib/__init__.py b/bagua/torch_api/contrib/__init__.py index 96786a381..c0b42217b 100644 --- a/bagua/torch_api/contrib/__init__.py +++ b/bagua/torch_api/contrib/__init__.py @@ -3,4 +3,5 @@ LoadBalancingDistributedSampler, LoadBalancingDistributedBatchSampler, ) -from .cached_dataset import CachedDataset +from .cache_loader import CacheLoader +from .cache_dataset import CacheDataset diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py new file mode 100644 index 000000000..a71892c12 --- /dev/null +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -0,0 +1,33 @@ +from torch.utils.data.dataset import Dataset +from .cache_loader import CacheLoader + + +class CacheDataset(Dataset): + def __init__( + self, + dataset: Dataset, + backend: str = "redis", + capacity_per_node: int = 100_000_000_000, + key_prefix: str = "", + batch_reads: int = 1, + batch_writes: int = 
20, + **kwargs, + ): + """ """ + + self.dataset = dataset + + self.cache_loader = CacheLoader( + backend, + capacity_per_node, + key_prefix, + batch_reads, + batch_writes, + **kwargs, + ) + + def __getitem__(self, item): + return self.cache_loader.get(item, lambda x: self.dataset[x]) + + def __len__(self): + return len(self.dataset) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cache_loader.py similarity index 70% rename from bagua/torch_api/contrib/cached_dataset.py rename to bagua/torch_api/contrib/cache_loader.py index 929936925..9d63cf84b 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -1,4 +1,3 @@ -from torch.utils.data.dataset import Dataset import pickle from collections import defaultdict @@ -11,30 +10,25 @@ def deserialize(input): return pickle.loads(input) -class CachedDataset(Dataset): +class CacheLoader: def __init__( self, - dataset: Dataset, backend: str = "redis", capacity_per_node: int = 100_000_000_000, key_prefix: str = "", batch_reads: int = 1, - batch_writes: int = 50, + batch_writes: int = 20, **kwargs, ): """ """ - self.dataset = dataset self.backend = backend + self.capacity_per_node = capacity_per_node self.key_prefix = key_prefix if backend == "redis": from .utils.redis_store import RedisStore self.store = RedisStore(capacity_per_node=capacity_per_node, **kwargs) - elif backend == "lmdb": - from .utils.lmdb_store import LmdbStore - - self.store = LmdbStore(capacity_per_node=capacity_per_node, **kwargs) else: raise ValueError( 'invalid backend, only support "redis" and "lmdb" at present' @@ -42,29 +36,29 @@ def __init__( self.fetcher = BatchFetcher(self.store, batch_reads, batch_writes) - def __getitem__(self, item): - key = "{}{}".format(self.key_prefix, item).encode() - - ret = self.fetcher.read(key) + def get(self, key, load_fn): + cache_key = "{}{}".format(self.key_prefix, key).encode() + ret = self.fetcher.read(cache_key) if ret == None: - ret = 
self.dataset[item] + ret = load_fn(key) # write to store - self.fetcher.write(key, ret) + self.fetcher.write(cache_key, ret) return ret - def __len__(self): - return len(self.dataset) + def num_keys(self): + return self.store.num_keys() def cleanup(self): + # TODO: cleanup automatically self.store.shutdown() class BatchFetcher: - def __init__(self, store, batch_reads=1, batch_writes=50): + def __init__(self, store, batch_reads, batch_writes): self.store = store - self.batch_reads = batch_reads - self.batch_writes = batch_writes + self.batch_reads = max(1, batch_reads) + self.batch_writes = max(1, batch_writes) self.write_map = defaultdict() self.write_cnt = 0 diff --git a/bagua/torch_api/contrib/utils/hash_func.py b/bagua/torch_api/contrib/utils/hash_func.py new file mode 100644 index 000000000..4b1fa789b --- /dev/null +++ b/bagua/torch_api/contrib/utils/hash_func.py @@ -0,0 +1,50 @@ +# reference: https://github.com/lammertb/libcrc/blob/master/src/crc16.c + + +# fmt: off +table = [ + 0x0000,0x1021,0x2042,0x3063,0x4084,0x50a5,0x60c6,0x70e7, + 0x8108,0x9129,0xa14a,0xb16b,0xc18c,0xd1ad,0xe1ce,0xf1ef, + 0x1231,0x0210,0x3273,0x2252,0x52b5,0x4294,0x72f7,0x62d6, + 0x9339,0x8318,0xb37b,0xa35a,0xd3bd,0xc39c,0xf3ff,0xe3de, + 0x2462,0x3443,0x0420,0x1401,0x64e6,0x74c7,0x44a4,0x5485, + 0xa56a,0xb54b,0x8528,0x9509,0xe5ee,0xf5cf,0xc5ac,0xd58d, + 0x3653,0x2672,0x1611,0x0630,0x76d7,0x66f6,0x5695,0x46b4, + 0xb75b,0xa77a,0x9719,0x8738,0xf7df,0xe7fe,0xd79d,0xc7bc, + 0x48c4,0x58e5,0x6886,0x78a7,0x0840,0x1861,0x2802,0x3823, + 0xc9cc,0xd9ed,0xe98e,0xf9af,0x8948,0x9969,0xa90a,0xb92b, + 0x5af5,0x4ad4,0x7ab7,0x6a96,0x1a71,0x0a50,0x3a33,0x2a12, + 0xdbfd,0xcbdc,0xfbbf,0xeb9e,0x9b79,0x8b58,0xbb3b,0xab1a, + 0x6ca6,0x7c87,0x4ce4,0x5cc5,0x2c22,0x3c03,0x0c60,0x1c41, + 0xedae,0xfd8f,0xcdec,0xddcd,0xad2a,0xbd0b,0x8d68,0x9d49, + 0x7e97,0x6eb6,0x5ed5,0x4ef4,0x3e13,0x2e32,0x1e51,0x0e70, + 0xff9f,0xefbe,0xdfdd,0xcffc,0xbf1b,0xaf3a,0x9f59,0x8f78, + 0x9188,0x81a9,0xb1ca,0xa1eb,0xd10c,0xc12d,0xf14e,0xe16f, 
+ 0x1080,0x00a1,0x30c2,0x20e3,0x5004,0x4025,0x7046,0x6067, + 0x83b9,0x9398,0xa3fb,0xb3da,0xc33d,0xd31c,0xe37f,0xf35e, + 0x02b1,0x1290,0x22f3,0x32d2,0x4235,0x5214,0x6277,0x7256, + 0xb5ea,0xa5cb,0x95a8,0x8589,0xf56e,0xe54f,0xd52c,0xc50d, + 0x34e2,0x24c3,0x14a0,0x0481,0x7466,0x6447,0x5424,0x4405, + 0xa7db,0xb7fa,0x8799,0x97b8,0xe75f,0xf77e,0xc71d,0xd73c, + 0x26d3,0x36f2,0x0691,0x16b0,0x6657,0x7676,0x4615,0x5634, + 0xd94c,0xc96d,0xf90e,0xe92f,0x99c8,0x89e9,0xb98a,0xa9ab, + 0x5844,0x4865,0x7806,0x6827,0x18c0,0x08e1,0x3882,0x28a3, + 0xcb7d,0xdb5c,0xeb3f,0xfb1e,0x8bf9,0x9bd8,0xabbb,0xbb9a, + 0x4a75,0x5a54,0x6a37,0x7a16,0x0af1,0x1ad0,0x2ab3,0x3a92, + 0xfd2e,0xed0f,0xdd6c,0xcd4d,0xbdaa,0xad8b,0x9de8,0x8dc9, + 0x7c26,0x6c07,0x5c64,0x4c45,0x3ca2,0x2c83,0x1ce0,0x0cc1, + 0xef1f,0xff3e,0xcf5d,0xdf7c,0xaf9b,0xbfba,0x8fd9,0x9ff8, + 0x6e17,0x7e36,0x4e55,0x5e74,0x2e93,0x3eb2,0x0ed1,0x1ef0 +] + + +def crc16(data: bytes): + hash_code = 0x000 + + for i in data: + hash_code = (hash_code >> 8) ^ table[(hash_code ^ i) & 0xFF] + + return hash_code + +if __name__ == "__main__": + print(crc16(b'abc')) diff --git a/bagua/torch_api/contrib/utils/lmdb_store.py b/bagua/torch_api/contrib/utils/lmdb_store.py deleted file mode 100644 index 735d8fcd2..000000000 --- a/bagua/torch_api/contrib/utils/lmdb_store.py +++ /dev/null @@ -1,61 +0,0 @@ -import lmdb -from .store import Store -from typing import List, Dict, Optional - - -class LmdbStore(Store): - def __init__(self, path, capacity_per_node: int = 1_000_000_000, overwrite=False): - self.path = path - self.capacity_per_node = capacity_per_node - self.env = lmdb.open(self.path, map_size=self.capacity_per_node) - - if overwrite: - self.clear() - - def set(self, key: str, value: str): - with self.env.begin(write=True) as txn: - txn.put(key, value) - - def get(self, key: str) -> Optional[str]: - with self.env.begin(write=False) as txn: - return txn.get(key) - - def num_keys(self) -> int: - return self.env.stat()["entries"] - - def clear(self) -> bool: - 
db = self.env.open_db() - - with self.env.begin(write=True) as txn: - txn.drop(db) - - return self.num_keys() - - def mset(self, mapping: Dict[str, str]): - kvpairs = list(zip(mapping.keys(), mapping.values())) - - with self.env.begin(write=True) as txn: - cursor = txn.cursor() - consumed_cnt, added_cnt = cursor.putmulti(kvpairs) - - if consumed_cnt != added_cnt: - raise RuntimeError( - "LmdbStore mset failed with: {}, failed to set {} items".format( - mapping, consumed_cnt - added_cnt - ) - ) - - def mget(self, keys: List[str]) -> List[Optional[str]]: - - with self.env.begin(write=False) as txn: - cursor = txn.cursor() - kvpairs = cursor.getmulti(keys) - - mapping = {k: v for k, v in kvpairs} - return list(map(lambda k: mapping.get(k, None), keys)) - - def status(self) -> bool: - return self.env.stat() - - def shutdown(self): - self.env.close() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 82c2e9301..319c6531d 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -2,20 +2,18 @@ import subprocess import time from bagua.torch_api.env import get_rank, get_local_rank, get_world_size, get_local_size - -# from rediscluster import RedisCluster from redis import Redis from typing import List, Dict, Optional -from .store import Store +from .store import Store, ClusterStore import torch.distributed.distributed_c10d as c10d -import torch import json import logging + _host_ip = None -class RedisStore(Store): +class RedisStore(ClusterStore): def __init__( self, hosts: List[Dict[str, str]] = None, @@ -31,7 +29,7 @@ def __init__( else: logging.info("Ready to connect redis servers: {}".format(hosts)) self.bootstrap = False - self.hosts.extends(hosts) + self.hosts.extend(hosts) self.cluster_mode = cluster_mode self.capacity_per_node = capacity_per_node @@ -39,45 +37,24 @@ def __init__( if self.bootstrap: self._bootstrap_redis_server() - self.client = 
self.create_redis_client(self.hosts) - - assert self.client.ping() - - def set(self, key: str, value: str): - self.client.set(key, value) - - def get(self, key: str) -> Optional[str]: - return self.client.get(key) - - def num_keys(self) -> int: - return self.client.dbsize() - - def clear(self) -> bool: - self.client.flushdb() - - def mset(self, mapping: Dict[str, str]): - self.client.mset(mapping) - - def mget(self, keys: List[str]) -> List[Optional[str]]: - return self.client.mget(keys) - - def status(self) -> bool: - return self.client.ping() + stores = [] + for h in self.hosts: + store = _RedisStore( + host=h["host"], port=h["port"], bootstrap=self.bootstrap + ) + stores.append(store) - def shutdown(self): - if self.bootstrap: - self.client.shutdown(nosave=True) + super(RedisStore, self).__init__(stores) def _bootstrap_redis_server(self): ip, port = get_host_ip(), find_free_port() hostinfo = {"host": ip, "port": port} if get_local_rank() == 0: - start_redis_server_cli(port, self.cluster_mode, self.capacity_per_node) - else: - wait_for_start_redis_server_cli() + start_redis_server_cli(port, self.capacity_per_node) + hosts = [] + nrank = get_rank() // get_local_size() if get_world_size() > 1: - nrank = get_rank() // get_local_size() nnodes = get_world_size() // get_local_size() default_store = c10d._get_default_store() key_pattern = "redis-node{}" @@ -87,49 +64,71 @@ def _bootstrap_redis_server(self): for i in range(nnodes): ret = json.loads(default_store.get(key_pattern.format(i))) - self.hosts.append(ret) + hosts.append(ret) else: - self.hosts.append(hostinfo) - - if not self.cluster_mode: - return - - # create_redis_cluster_cli(self.hosts) - # wait_for_create_redis_cluster_cli() + hosts.append(hostinfo) - def create_redis_client(self, hosts): if self.cluster_mode: - raise ValueError("RedisStore does not support cluster mode at present") - # self.client = RedisCluster(startup_nodes=self.hosts, decode_responses=True) + self.hosts.extend(hosts) else: - nrank = 
get_rank() // get_local_size() - hostinfo = hosts[nrank % len(self.hosts)] - - logging.debug(f"{get_host_ip()} connect to redis server: {hostinfo}") - return ( - Redis(port=hostinfo["port"]) - if hostinfo["host"] == get_host_ip() - else Redis(host=host["info"], port=hostinfo["port"]) - ) + self.hosts.append(hosts[nrank]) + + +class _RedisStore(Store): + def __init__(self, host, port, bootstrap): + self.client = create_redis_client(host=host, port=port) + self.bootstrap = bootstrap + + assert self._connect_with_retry( + retry_times=3 + ), "Could not connect to redis server {}:{}".format(host, port) + + def _connect_with_retry(self, retry_times=3): + for i in range(retry_times): + try: + connected = self.client.ping() + except Exception as e: + if i == retry_times - 1: + return False + + time.sleep(10) + else: + return connected + + return False + + def set(self, key: str, value: str): + self.client.set(key, value) + + def get(self, key: str) -> Optional[str]: + return self.client.get(key) + + def num_keys(self) -> int: + return self.client.dbsize() + def clear(self) -> bool: + self.client.flushdb() -def create_redis_cluster_cli(hosts: List[Dict[str, str]]): - cmd = ["redis-cli", "--cluster", "create"] + def mset(self, mapping: Dict[str, str]): + self.client.mset(mapping) - for h in hosts: - cmd.append("{}:{}".format(h["host"], h["port"])) + def mget(self, keys: List[str]) -> List[Optional[str]]: + return self.client.mget(keys) - logging.debug(f"create redis cluster, command: {cmd}") + def status(self) -> bool: + return self.client.ping() - subprocess.run(cmd, capture_output=True, text=True, input="yes") - time.sleep(5) + def shutdown(self): + if self.bootstrap: + self.client.shutdown(nosave=True) -def wait_for_create_redis_cluster_cli(): - time.sleep(5) +def create_redis_client(host, port): + logging.debug(f"{get_host_ip()} connect to redis server: {host}:{port}") + return Redis(port=port) if host == get_host_ip() else Redis(host=host, port=port) -def 
start_redis_server_cli(port, cluster_mode, capacity, *args): +def start_redis_server_cli(port, capacity, *args): cmd = [ "redis-server", "--daemonize yes", @@ -138,24 +137,12 @@ def start_redis_server_cli(port, cluster_mode, capacity, *args): "--maxmemory-policy allkeys-random", # use random eviction by default "--appendonly no", # disable persistence by default '--save ""', + "--protected-mode no" ] - if cluster_mode: - cluster_config = [ - "--cluster-enabled yes", - "--cluster-config-file nodes.conf", - "--cluster-node-timeout 5000", - ] - cmd.extend(cluster_config) - cmd.extend(list(args)) logging.debug(f"start redis server, command: {cmd}") subprocess.run(cmd) - time.sleep(10) - - -def wait_for_start_redis_server_cli(): - time.sleep(10) def find_free_port(): diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 156fb339f..2b10daaf3 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -1,4 +1,6 @@ -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Any +from .hash_func import crc16 +from collections import defaultdict class Store: @@ -11,7 +13,7 @@ def get(self, key: str) -> Optional[str]: def num_keys(self) -> int: pass - def cleanup(self) -> bool: + def clear(self) -> bool: pass def mset(self, mapping: Dict[str, str]): @@ -25,3 +27,77 @@ def status(self) -> bool: def shutdown(self): pass + + +class ClusterStore(Store): + def __init__(self, stores: List[Store]): + self.stores = stores + self.num_stores = len(stores) + + def _hash_key(self, key): + hash_code = crc16(key) + return hash_code % self.num_stores + + def route(self, key) -> Store: + return ( + self.stores[self._hash_key(key)] if self.num_stores > 1 else self.stores[0] + ) + + def set(self, key: str, value: str): + if self.num_stores == 1: + return self.stores[0].set(key, value) + + self.route(key).set(key, value) + + def get(self, key: str) -> Optional[str]: + if self.num_stores 
== 1: + return self.stores[0].get(key) + + return self.route(key).get(key) + + def num_keys(self) -> int: + return sum([store.num_keys() for store in self.stores]) + + def clear(self) -> bool: + for store in self.stores: + store.clear() + + def mset(self, mapping: Dict[str, str]): + if self.num_stores == 1: + return self.stores[0].mset(mapping) + + route_table = {} + for k, v in mapping.items(): + sid = self._hash_key(k) + m = route_table.get(sid, defaultdict(dict)) + m[k] = v + route_table[sid] = m + + for sid, m in route_table.items(): + self.stores[sid].mset(m) + + def mget(self, keys: List[str]) -> List[Optional[str]]: + if self.num_stores == 1: + return self.stores[0].mget(keys) + + route_table = {} + for k in keys: + sid = self._hash_key(k) + l = route_table.get(sid, []) + l.append(k) + route_table[sid] = l + + result_map = {} + for sid, l in route_table.items(): + ret = self.stores[sid].mget(l) + m = {k: v for k, v in zip(l, ret)} + result_map = {**result_map, **m} + + return list(map(lambda x: result_map.get(x, None), keys)) + + def status(self) -> bool: + return all([store.status() for store in self.stores]) + + def shutdown(self): + for store in self.stores: + store.shutdown() diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 6c265c6ca..74ab263df 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -1,4 +1,4 @@ -from bagua.torch_api.contrib.cached_dataset import CachedDataset +from bagua.torch_api.contrib.cache_dataset import CacheDataset from torch.utils.data.dataset import Dataset import numpy as np import logging @@ -19,33 +19,23 @@ def __len__(self): return self.size -class TestCachedDataset(unittest.TestCase): - def check_dataset(self, dataset, cached_dataset): +class TestCacheDataset(unittest.TestCase): + def check_dataset(self, dataset, cache_dataset): for _ in range(10): - for _, _ in enumerate(cached_dataset): + for _, _ in enumerate(cache_dataset): pass - 
self.assertEqual(cached_dataset.store.num_keys(), len(dataset)) + self.assertEqual(cache_dataset.cache_loader.num_keys(), len(dataset)) for i in range(len(dataset)): - self.assertTrue((dataset[i][0] == cached_dataset[i][0]).all()) - self.assertTrue((dataset[i][1] == cached_dataset[i][1]).all()) - - def test_lmdb(self): - dataset = TestDataset(102) - cached_dataset = CachedDataset( - dataset, backend="lmdb", path=".lmdb", overwrite=True - ) - self.check_dataset(dataset, cached_dataset) - - cached_dataset.cleanup() + self.assertTrue((dataset[i][0] == cache_dataset[i][0]).all()) + self.assertTrue((dataset[i][1] == cache_dataset[i][1]).all()) def test_redis(self): dataset = TestDataset(102) - cached_dataset = CachedDataset( + cache_dataset = CacheDataset( dataset, backend="redis", hosts=None, cluster_mode=False ) - self.check_dataset(dataset, cached_dataset) - cached_dataset.cleanup() + self.check_dataset(dataset, cache_dataset) if __name__ == "__main__": diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index de16973c0..5328aa8f5 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -2,51 +2,18 @@ from bagua.torch_api.contrib.utils.redis_store import ( RedisStore, start_redis_server_cli, - create_redis_cluster_cli, find_free_port, ) -from bagua.torch_api.contrib.utils.lmdb_store import LmdbStore import redis import multiprocessing as mp import logging import numpy as np import pickle +import time logging.basicConfig(level=logging.DEBUG) -class TestLmdbStore(unittest.TestCase): - def check(self, store): - store.set(b"Beijing", b"China") - store.set(b"Paris", b"France") - - store.mset({b"New Delhi": b"India", b"Tokyo": b"Japan", b"Madrid": b"Spain"}) - ret = store.mget([b"Beijing", b"London", b"Tokyo"]) - self.assertEqual(ret[0], b"China") - self.assertEqual(ret[1], None) - self.assertEqual(ret[2], b"Japan") - - r1 = store.get(b"Madrid") - r2 = store.get(b"Shanghai") - self.assertEqual(r1, b"Spain") - self.assertEqual(r2, 
None) - - cnt = store.num_keys() - self.assertEqual(cnt, 5) - - store.clear() - self.assertEqual(store.num_keys(), 0) - - self.assertTrue(store.status()) - - # shut down resources at the end - store.shutdown() - - def test_lmdb_store(self): - store = LmdbStore(path=".lmdb", capacity_per_node=10000000, overwrite=True) - self.check(store) - - class TestRedisStore(unittest.TestCase): def check(self, store): self.generated_data = [np.random.rand(10) for _ in range(5)] @@ -56,34 +23,35 @@ def check(self, store): { "2": pickle.dumps(self.generated_data[2]), "3": pickle.dumps(self.generated_data[3]), + "4": pickle.dumps(self.generated_data[4]), } ) - ret = store.mget(["2", "4"]) - self.assertTrue((pickle.loads(ret[0]) == self.generated_data[2]).all()) - self.assertEqual(ret[1], None) - - r1 = store.get("1") - r2 = store.get("4") - self.assertTrue((pickle.loads(r1) == self.generated_data[1]).all()) + ret = store.mget(["1", "2", "5"]) + self.assertTrue((pickle.loads(ret[0]) == self.generated_data[1]).all()) + self.assertTrue((pickle.loads(ret[1]) == self.generated_data[2]).all()) + self.assertEqual(ret[2], None) + + r1 = store.get("4") + r2 = store.get("6") + self.assertTrue((pickle.loads(r1) == self.generated_data[4]).all()) self.assertEqual(r2, None) cnt = store.num_keys() - self.assertEqual(cnt, 3) + self.assertEqual(cnt, 4) store.clear() self.assertEqual(store.num_keys(), 0) self.assertTrue(store.status()) + # try to shut down resources + store.shutdown() + def test_redis_store(self): store = RedisStore(hosts=None, cluster_mode=False, capacity_per_node=10000000) self.check(store) - # try to shut down resources - store.shutdown() - def test_redis_cluster_store(self): - return n = 3 hosts = [] ports = [] @@ -92,7 +60,7 @@ def test_redis_cluster_store(self): port = find_free_port() p = mp.Process( target=start_redis_server_cli, - args=(port, True, 10000000, f"--cluster-config-file nodes{port}.conf"), + args=(port, 10000000, f"--cluster-config-file nodes{port}.conf"), ) 
p.start() @@ -103,8 +71,6 @@ def test_redis_cluster_store(self): for p in processes: p.join() - create_redis_cluster_cli(hosts=hosts) - store = RedisStore(hosts=hosts, cluster_mode=True, capacity_per_node=10000000) self.check(store) From 241b21e7a01a8ae944978f017e8d2a2f18c317e8 Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Wed, 4 Aug 2021 18:27:33 +0800 Subject: [PATCH 09/63] Update bagua/torch_api/contrib/utils/redis_store.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- bagua/torch_api/contrib/utils/redis_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 319c6531d..d063607cf 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -137,7 +137,7 @@ def start_redis_server_cli(port, capacity, *args): "--maxmemory-policy allkeys-random", # use random eviction by default "--appendonly no", # disable persistence by default '--save ""', - "--protected-mode no" + "--protected-mode no", ] cmd.extend(list(args)) From 14d6d17da8db302c2e67c0fc64493ad585d4570d Mon Sep 17 00:00:00 2001 From: ritaw Date: Wed, 11 Aug 2021 23:25:03 +0800 Subject: [PATCH 10/63] add --- bagua/torch_api/contrib/__init__.py | 4 +- bagua/torch_api/contrib/cache_dataset.py | 35 ++++++--- bagua/torch_api/contrib/cache_loader.py | 62 ++++++++++++---- bagua/torch_api/contrib/utils/hash_func.py | 75 ++++++++++---------- bagua/torch_api/contrib/utils/redis_store.py | 17 +++++ bagua/torch_api/contrib/utils/store.py | 24 ++++++- 6 files changed, 149 insertions(+), 68 deletions(-) diff --git a/bagua/torch_api/contrib/__init__.py b/bagua/torch_api/contrib/__init__.py index c0b42217b..0dfe649d6 100644 --- a/bagua/torch_api/contrib/__init__.py +++ b/bagua/torch_api/contrib/__init__.py @@ -3,5 +3,5 @@ LoadBalancingDistributedSampler, 
LoadBalancingDistributedBatchSampler, ) -from .cache_loader import CacheLoader -from .cache_dataset import CacheDataset +from .cache_loader import CacheLoader # noqa: F401 +from .cache_dataset import CacheDataset # noqa: F401 diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index a71892c12..8d2de3e4f 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -1,30 +1,43 @@ from torch.utils.data.dataset import Dataset from .cache_loader import CacheLoader +__all__ = ["CacheDataset"] + class CacheDataset(Dataset): + """ + A dataset wrapper which caches `dataset` samples. + + Args: + dataset: Dataset used for caching. + backend(str): The backend to use. Currently "redis" is supported, which means to use :class:`RedisStore`. + key_prefix(str): Prefix of the cache key. Default ``""``. + batch_writes(int): How many key-value pairs written to cache once. Default ``20``. + + Example:: + + >>> from bagua.torch_api.contrib import CacheDataset + >>> cache_dataset = CacheDataset( +... dataset, backend="redis", hosts=None, cluster_mode=False +... ) + >>> dataloader = torch.utils.data.DataLoader(cached_dataset) + + .. note:: + This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. 
+ """ + def __init__( self, dataset: Dataset, backend: str = "redis", - capacity_per_node: int = 100_000_000_000, key_prefix: str = "", - batch_reads: int = 1, batch_writes: int = 20, **kwargs, ): - """ """ self.dataset = dataset - self.cache_loader = CacheLoader( - backend, - capacity_per_node, - key_prefix, - batch_reads, - batch_writes, - **kwargs, - ) + self.cache_loader = CacheLoader(backend, key_prefix, batch_writes, **kwargs,) def __getitem__(self, item): return self.cache_loader.get(item, lambda x: self.dataset[x]) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 9d63cf84b..6fd74baa7 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -1,6 +1,8 @@ import pickle from collections import defaultdict +__all__ = ["CacheLoader"] + def serialize(input): return pickle.dumps(input) @@ -14,42 +16,61 @@ class CacheLoader: def __init__( self, backend: str = "redis", - capacity_per_node: int = 100_000_000_000, key_prefix: str = "", - batch_reads: int = 1, - batch_writes: int = 20, + batch_writes: int = 1, **kwargs, ): - """ """ + """ + A mapping from keys to values. Values are automatically loaded by the cache, and + are stored in the cache until evicted. + + Args: + backend(str): The backend to use. Currently "redis" is supported, which means to use :class:`RedisStore`. + key_prefix(str): Prefix of the cache key. Default ``""``. + batch_writes(int): How many key-value pairs written to cache once. Default ``1``. 
+ + Example:: + >>> # redis server '127.0.0.1:7000' must be alive beforehand + >>> hosts = [{"host": "127.0.0.1", "port": "7000"}] + >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=False) + >>> + >>> loader.get(index, lambda x: items[x]) + """ + self.backend = backend - self.capacity_per_node = capacity_per_node self.key_prefix = key_prefix if backend == "redis": from .utils.redis_store import RedisStore - self.store = RedisStore(capacity_per_node=capacity_per_node, **kwargs) + self.store = RedisStore(**kwargs) else: - raise ValueError( - 'invalid backend, only support "redis" and "lmdb" at present' - ) + raise ValueError('invalid backend, only support "redis" currently') - self.fetcher = BatchFetcher(self.store, batch_reads, batch_writes) + self.fetcher = BatchFetcher(self.store, 1, batch_writes) def get(self, key, load_fn): + """ + Returns the value associated with key in cache, first loading the value by calling `load_fn(key)` if necessary. + """ + cache_key = "{}{}".format(self.key_prefix, key).encode() ret = self.fetcher.read(cache_key) - if ret == None: + if ret is None: ret = load_fn(key) # write to store self.fetcher.write(cache_key, ret) return ret def num_keys(self): + """Returns total number of keys in cache""" + return self.store.num_keys() def cleanup(self): + """Cleanup the resources used.""" + # TODO: cleanup automatically self.store.shutdown() @@ -69,8 +90,13 @@ def __init__(self, store, batch_reads, batch_writes): def read(self, key): self.read_cnt += 1 - ret = self.store.get(key) - self.write_post_read() + try: + ret = self.store.get(key) + except: + ret = None + else: + self.write_post_read() + if ret is not None: return deserialize(ret) return ret @@ -80,10 +106,16 @@ def write(self, key, value): self.write_map[key] = serialize(value) if self.write_cnt % self.batch_writes == 0: - self.store.mset(self.write_map) - self.write_map.clear() + self.flush_write_map() def write_post_read(self): if self.read_cnt % 1000 == 0 and 
len(self.write_map) > 0: + self.flush_write_map() + + def flush_write_map(self): + try: self.store.mset(self.write_map) + except: + pass + else: self.write_map.clear() diff --git a/bagua/torch_api/contrib/utils/hash_func.py b/bagua/torch_api/contrib/utils/hash_func.py index 4b1fa789b..77b10aeb4 100644 --- a/bagua/torch_api/contrib/utils/hash_func.py +++ b/bagua/torch_api/contrib/utils/hash_func.py @@ -1,50 +1,51 @@ -# reference: https://github.com/lammertb/libcrc/blob/master/src/crc16.c +__all__ = [] +# reference: https://github.com/lammertb/libcrc/blob/master/src/crc16.c # fmt: off table = [ - 0x0000,0x1021,0x2042,0x3063,0x4084,0x50a5,0x60c6,0x70e7, - 0x8108,0x9129,0xa14a,0xb16b,0xc18c,0xd1ad,0xe1ce,0xf1ef, - 0x1231,0x0210,0x3273,0x2252,0x52b5,0x4294,0x72f7,0x62d6, - 0x9339,0x8318,0xb37b,0xa35a,0xd3bd,0xc39c,0xf3ff,0xe3de, - 0x2462,0x3443,0x0420,0x1401,0x64e6,0x74c7,0x44a4,0x5485, - 0xa56a,0xb54b,0x8528,0x9509,0xe5ee,0xf5cf,0xc5ac,0xd58d, - 0x3653,0x2672,0x1611,0x0630,0x76d7,0x66f6,0x5695,0x46b4, - 0xb75b,0xa77a,0x9719,0x8738,0xf7df,0xe7fe,0xd79d,0xc7bc, - 0x48c4,0x58e5,0x6886,0x78a7,0x0840,0x1861,0x2802,0x3823, - 0xc9cc,0xd9ed,0xe98e,0xf9af,0x8948,0x9969,0xa90a,0xb92b, - 0x5af5,0x4ad4,0x7ab7,0x6a96,0x1a71,0x0a50,0x3a33,0x2a12, - 0xdbfd,0xcbdc,0xfbbf,0xeb9e,0x9b79,0x8b58,0xbb3b,0xab1a, - 0x6ca6,0x7c87,0x4ce4,0x5cc5,0x2c22,0x3c03,0x0c60,0x1c41, - 0xedae,0xfd8f,0xcdec,0xddcd,0xad2a,0xbd0b,0x8d68,0x9d49, - 0x7e97,0x6eb6,0x5ed5,0x4ef4,0x3e13,0x2e32,0x1e51,0x0e70, - 0xff9f,0xefbe,0xdfdd,0xcffc,0xbf1b,0xaf3a,0x9f59,0x8f78, - 0x9188,0x81a9,0xb1ca,0xa1eb,0xd10c,0xc12d,0xf14e,0xe16f, - 0x1080,0x00a1,0x30c2,0x20e3,0x5004,0x4025,0x7046,0x6067, - 0x83b9,0x9398,0xa3fb,0xb3da,0xc33d,0xd31c,0xe37f,0xf35e, - 0x02b1,0x1290,0x22f3,0x32d2,0x4235,0x5214,0x6277,0x7256, - 0xb5ea,0xa5cb,0x95a8,0x8589,0xf56e,0xe54f,0xd52c,0xc50d, - 0x34e2,0x24c3,0x14a0,0x0481,0x7466,0x6447,0x5424,0x4405, - 0xa7db,0xb7fa,0x8799,0x97b8,0xe75f,0xf77e,0xc71d,0xd73c, - 
0x26d3,0x36f2,0x0691,0x16b0,0x6657,0x7676,0x4615,0x5634, - 0xd94c,0xc96d,0xf90e,0xe92f,0x99c8,0x89e9,0xb98a,0xa9ab, - 0x5844,0x4865,0x7806,0x6827,0x18c0,0x08e1,0x3882,0x28a3, - 0xcb7d,0xdb5c,0xeb3f,0xfb1e,0x8bf9,0x9bd8,0xabbb,0xbb9a, - 0x4a75,0x5a54,0x6a37,0x7a16,0x0af1,0x1ad0,0x2ab3,0x3a92, - 0xfd2e,0xed0f,0xdd6c,0xcd4d,0xbdaa,0xad8b,0x9de8,0x8dc9, - 0x7c26,0x6c07,0x5c64,0x4c45,0x3ca2,0x2c83,0x1ce0,0x0cc1, - 0xef1f,0xff3e,0xcf5d,0xdf7c,0xaf9b,0xbfba,0x8fd9,0x9ff8, - 0x6e17,0x7e36,0x4e55,0x5e74,0x2e93,0x3eb2,0x0ed1,0x1ef0 + 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, + 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, + 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, + 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, + 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, + 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, + 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, + 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, + 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, + 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, + 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, + 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, + 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, + 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, + 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, + 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, + 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, + 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, + 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, + 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, + 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, + 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, + 0xa7db, 0xb7fa, 
0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, + 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, + 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, + 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, + 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, + 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, + 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, + 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, + 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, + 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0 ] -def crc16(data: bytes): +def crc16(data): + if isinstance(data, str): + data = data.encode() + hash_code = 0x000 for i in data: hash_code = (hash_code >> 8) ^ table[(hash_code ^ i) & 0xFF] return hash_code - -if __name__ == "__main__": - print(crc16(b'abc')) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index d063607cf..74e9f7485 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -9,11 +9,28 @@ import json import logging +__all__ = ["RedisStore"] _host_ip = None class RedisStore(ClusterStore): + """ + A Redis-based store implementation. + + The server holds the data, while the client can connect to the server over Redis protocal and perform + actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. + + Args: + hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of "host" and "port" mappings. Can be ``None``, which + means to bootstrap redis servers locally by Bagua processes. + cluster_mode (bool): View redis servers as a cluster or not. If True, data is automatically sharded across all redis servers, + otherwise, each process connects to and stores data to only one redis server. In bootstrapped cases, each process connects to + its local redis server. 
+ capacity_per_node (int): Maximum memory limit in bytes to configure bootstrapped redis servers. Redis servers will randomly evict + keys when maximum memory limit reached. + """ + def __init__( self, hosts: List[Dict[str, str]] = None, diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 2b10daaf3..24ea3c95d 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -1,9 +1,14 @@ from typing import List, Dict, Optional, Any -from .hash_func import crc16 from collections import defaultdict class Store: + """ + Base class for all store implementations. A store keeps a mapping from keys to values. + key-value pairs are manually added to store using `set()` or `mset()` and can be retrieved by + `get()` or `mget()`. + """ + def set(self, key: str, value: str): pass @@ -30,12 +35,25 @@ def shutdown(self): class ClusterStore(Store): - def __init__(self, stores: List[Store]): + """ + An implementation for a cluster of stores. + + Data is sharded on client side. Default hashing algorithm for the shard key is CRC-16. Can + accept customized hashing algorithms by passing `hash_fn` on initialization. 
+ """ + + def __init__(self, stores: List[Store], hash_fn=None): self.stores = stores self.num_stores = len(stores) + if hash_fn is None: + from .hash_func import crc16 + + hash_fn = crc16 + self.hash_fn = hash_fn + def _hash_key(self, key): - hash_code = crc16(key) + hash_code = self.hash_fn(key) return hash_code % self.num_stores def route(self, key) -> Store: From 4b7a7d69d0b26c2547898e69245ab89ac1e5076d Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Wed, 11 Aug 2021 23:26:25 +0800 Subject: [PATCH 11/63] Update bagua/torch_api/contrib/cache_dataset.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- bagua/torch_api/contrib/cache_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index 8d2de3e4f..f39a35743 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -22,8 +22,8 @@ class CacheDataset(Dataset): ... ) >>> dataloader = torch.utils.data.DataLoader(cached_dataset) - .. note:: - This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. + .. note:: + This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. 
""" def __init__( From 0bb1d150fc1363feb4cf0ac016051feba0d48f84 Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Wed, 11 Aug 2021 23:26:32 +0800 Subject: [PATCH 12/63] Update bagua/torch_api/contrib/cache_dataset.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- bagua/torch_api/contrib/cache_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index f39a35743..2118918cc 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -16,11 +16,11 @@ class CacheDataset(Dataset): Example:: - >>> from bagua.torch_api.contrib import CacheDataset - >>> cache_dataset = CacheDataset( -... dataset, backend="redis", hosts=None, cluster_mode=False -... ) - >>> dataloader = torch.utils.data.DataLoader(cached_dataset) + >>> from bagua.torch_api.contrib import CacheDataset + >>> cache_dataset = CacheDataset( + ... dataset, backend="redis", hosts=None, cluster_mode=False + ... ) + >>> dataloader = torch.utils.data.DataLoader(cached_dataset) .. note:: This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. 
From be8dfa86071dcbbe5e9e678abedc9d0ac12c4d5f Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Wed, 11 Aug 2021 23:27:23 +0800 Subject: [PATCH 13/63] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- bagua/torch_api/contrib/cache_dataset.py | 21 ++++++++++++-------- bagua/torch_api/contrib/utils/redis_store.py | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index 2118918cc..5fc5f3bd0 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -6,15 +6,15 @@ class CacheDataset(Dataset): """ - A dataset wrapper which caches `dataset` samples. + A dataset wrapper which caches `dataset` samples. - Args: - dataset: Dataset used for caching. - backend(str): The backend to use. Currently "redis" is supported, which means to use :class:`RedisStore`. - key_prefix(str): Prefix of the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``20``. + Args: + dataset: Dataset used for caching. + backend(str): The backend to use. Currently "redis" is supported, which means to use :class:`RedisStore`. + key_prefix(str): Prefix of the cache key. Default ``""``. + batch_writes(int): How many key-value pairs written to cache once. Default ``20``. 
- Example:: + Example:: >>> from bagua.torch_api.contrib import CacheDataset >>> cache_dataset = CacheDataset( @@ -37,7 +37,12 @@ def __init__( self.dataset = dataset - self.cache_loader = CacheLoader(backend, key_prefix, batch_writes, **kwargs,) + self.cache_loader = CacheLoader( + backend, + key_prefix, + batch_writes, + **kwargs, + ) def __getitem__(self, item): return self.cache_loader.get(item, lambda x: self.dataset[x]) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 74e9f7485..b5bf2cf11 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -27,7 +27,7 @@ class RedisStore(ClusterStore): cluster_mode (bool): View redis servers as a cluster or not. If True, data is automatically sharded across all redis servers, otherwise, each process connects to and stores data to only one redis server. In bootstrapped cases, each process connects to its local redis server. - capacity_per_node (int): Maximum memory limit in bytes to configure bootstrapped redis servers. Redis servers will randomly evict + capacity_per_node (int): Maximum memory limit in bytes to configure bootstrapped redis servers. Redis servers will randomly evict keys when maximum memory limit reached. 
""" From a1a9c78908b9ea4c1af551d3ec39c354baeb682b Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Wed, 11 Aug 2021 23:28:02 +0800 Subject: [PATCH 14/63] Update bagua/torch_api/contrib/cache_dataset.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- bagua/torch_api/contrib/cache_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index 5fc5f3bd0..3d7348497 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -22,8 +22,8 @@ class CacheDataset(Dataset): ... ) >>> dataloader = torch.utils.data.DataLoader(cached_dataset) - .. note:: - This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. + .. note:: + This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. 
""" def __init__( From e864347389a4c94965258930ae90b69e1ed5b787 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 10:56:50 +0800 Subject: [PATCH 15/63] format --- bagua/torch_api/contrib/cache_loader.py | 4 ++-- bagua/torch_api/contrib/utils/redis_store.py | 2 +- bagua/torch_api/contrib/utils/store.py | 14 +++++++------- tests/contrib/test_store.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 6fd74baa7..9cdec87e1 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -92,7 +92,7 @@ def read(self, key): try: ret = self.store.get(key) - except: + except Exception: ret = None else: self.write_post_read() @@ -115,7 +115,7 @@ def write_post_read(self): def flush_write_map(self): try: self.store.mset(self.write_map) - except: + except Exception: pass else: self.write_map.clear() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index b5bf2cf11..68a8edfbf 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -104,7 +104,7 @@ def _connect_with_retry(self, retry_times=3): for i in range(retry_times): try: connected = self.client.ping() - except Exception as e: + except Exception: if i == retry_times - 1: return False diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 24ea3c95d..08b9c3257 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Optional, Any +from typing import List, Dict, Optional from collections import defaultdict @@ -101,14 +101,14 @@ def mget(self, keys: List[str]) -> List[Optional[str]]: route_table = {} for k in keys: sid = self._hash_key(k) - l = route_table.get(sid, []) - l.append(k) - route_table[sid] = l + ll = route_table.get(sid, []) 
+ ll.append(k) + route_table[sid] = ll result_map = {} - for sid, l in route_table.items(): - ret = self.stores[sid].mget(l) - m = {k: v for k, v in zip(l, ret)} + for sid, ll in route_table.items(): + ret = self.stores[sid].mget(ll) + m = {k: v for k, v in zip(ll, ret)} result_map = {**result_map, **m} return list(map(lambda x: result_map.get(x, None), keys)) diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index 5328aa8f5..d5ba2dbf1 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -9,7 +9,7 @@ import logging import numpy as np import pickle -import time + logging.basicConfig(level=logging.DEBUG) From 7ab867559300955eaaabdae724448b9741ac06fe Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 12:04:33 +0800 Subject: [PATCH 16/63] update doc --- bagua/torch_api/contrib/cache_dataset.py | 36 +++++++++++------------ bagua/torch_api/contrib/cache_loader.py | 15 ++++++---- bagua/torch_api/contrib/utils/__init__.py | 1 + bagua/torch_api/contrib/utils/store.py | 14 +++++++-- tests/contrib/test_store.py | 2 +- 5 files changed, 40 insertions(+), 28 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index 3d7348497..a5772848d 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -5,35 +5,35 @@ class CacheDataset(Dataset): - """ + def __init__( + self, + dataset: Dataset, + backend: str = "redis", + key_prefix: str = "", + batch_writes: int = 20, + **kwargs, + ): + """ A dataset wrapper which caches `dataset` samples. Args: dataset: Dataset used for caching. - backend(str): The backend to use. Currently "redis" is supported, which means to use :class:`RedisStore`. + backend(str): The backend to use. Currently "redis" is supported. key_prefix(str): Prefix of the cache key. Default ``""``. batch_writes(int): How many key-value pairs written to cache once. Default ``20``. 
Example:: - >>> from bagua.torch_api.contrib import CacheDataset - >>> cache_dataset = CacheDataset( - ... dataset, backend="redis", hosts=None, cluster_mode=False - ... ) - >>> dataloader = torch.utils.data.DataLoader(cached_dataset) + >>> from bagua.torch_api.contrib import CacheDataset + >>> cache_dataset = CacheDataset( + ... dataset, backend="redis", hosts=None, cluster_mode=False + ... ) + >>> dataloader = torch.utils.data.DataLoader(cached_dataset) - .. note:: - This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. - """ + .. note:: - def __init__( - self, - dataset: Dataset, - backend: str = "redis", - key_prefix: str = "", - batch_writes: int = 20, - **kwargs, - ): + This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. + """ self.dataset = dataset diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 9cdec87e1..2dabe02d6 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -25,16 +25,18 @@ def __init__( are stored in the cache until evicted. Args: - backend(str): The backend to use. Currently "redis" is supported, which means to use :class:`RedisStore`. + backend(str): The backend to use. Currently "redis" is supported. key_prefix(str): Prefix of the cache key. Default ``""``. batch_writes(int): How many key-value pairs written to cache once. Default ``1``. 
Example:: - >>> # redis server '127.0.0.1:7000' must be alive beforehand - >>> hosts = [{"host": "127.0.0.1", "port": "7000"}] - >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=False) + To use an initialized redis clusters: {'192.168.1.0:7000', '192.168.1.1:7000'} + + >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}] + >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True) >>> >>> loader.get(index, lambda x: items[x]) + """ self.backend = backend @@ -51,7 +53,8 @@ def __init__( def get(self, key, load_fn): """ - Returns the value associated with key in cache, first loading the value by calling `load_fn(key)` if necessary. + Returns the value associated with key in cache, first loading the value if necessary. + `load_fn` accepts `key` as input, and returns an object ser """ cache_key = "{}{}".format(self.key_prefix, key).encode() @@ -64,7 +67,7 @@ def get(self, key, load_fn): return ret def num_keys(self): - """Returns total number of keys in cache""" + """Returns the total number of keys in cache""" return self.store.num_keys() diff --git a/bagua/torch_api/contrib/utils/__init__.py b/bagua/torch_api/contrib/utils/__init__.py index e69de29bb..d81361da5 100644 --- a/bagua/torch_api/contrib/utils/__init__.py +++ b/bagua/torch_api/contrib/utils/__init__.py @@ -0,0 +1 @@ +__all__ = ["redis_store", "store"] diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 08b9c3257..3b6c3d8ce 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -2,6 +2,9 @@ from collections import defaultdict +__all__ = ["Store", "ClusterStore"] + + class Store: """ Base class for all store implementations. A store keeps a mapping from keys to values. @@ -36,13 +39,18 @@ def shutdown(self): class ClusterStore(Store): """ - An implementation for a cluster of stores. + An implementation for a store cluster. 
- Data is sharded on client side. Default hashing algorithm for the shard key is CRC-16. Can + This class implements client side sharding. It uses CRC-16 algorithm to compute the shard key by default, and can accept customized hashing algorithms by passing `hash_fn` on initialization. + + key-value pairs are manually added to the cluster using `set()` or `mset()` and can be retrieved by + `get()` or `mget()`. + """ def __init__(self, stores: List[Store], hash_fn=None): + self.stores = stores self.num_stores = len(stores) @@ -52,7 +60,7 @@ def __init__(self, stores: List[Store], hash_fn=None): hash_fn = crc16 self.hash_fn = hash_fn - def _hash_key(self, key): + def _hash_key(self, key) -> int: hash_code = self.hash_fn(key) return hash_code % self.num_stores diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index d5ba2dbf1..69f36e31f 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -71,7 +71,7 @@ def test_redis_cluster_store(self): for p in processes: p.join() - store = RedisStore(hosts=hosts, cluster_mode=True, capacity_per_node=10000000) + store = RedisStore(hosts=hosts, cluster_mode=True) self.check(store) self.assertTrue(store.status()) From 53ae81145597255164f8a617c313fe37fd887915 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 12:06:12 +0800 Subject: [PATCH 17/63] . --- bagua/torch_api/contrib/utils/redis_store.py | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 68a8edfbf..d90e58bfc 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -18,17 +18,18 @@ class RedisStore(ClusterStore): """ A Redis-based store implementation. 
- The server holds the data, while the client can connect to the server over Redis protocal and perform - actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc. + The server holds the data, while the client can connect to the server over Redis protocol and perform + actions such as `set()` to insert a key-value pair, `get()` to retrieve a key-value pair, etc. Args: - hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of "host" and "port" mappings. Can be ``None``, which - means to bootstrap redis servers locally by Bagua processes. - cluster_mode (bool): View redis servers as a cluster or not. If True, data is automatically sharded across all redis servers, - otherwise, each process connects to and stores data to only one redis server. In bootstrapped cases, each process connects to - its local redis server. - capacity_per_node (int): Maximum memory limit in bytes to configure bootstrapped redis servers. Redis servers will randomly evict - keys when maximum memory limit reached. + hosts (List[Dict[str, str]]): A list of redis servers. Can be ``None``, which means to bootstrap redis servers + locally. + cluster_mode (bool): Redis servers serve as a cluster or not. If True, data is automatically sharded across all + redis servers. + capacity_per_node (int): Maximum memory limit in bytes to configure redis servers. Useful only for local bootstrapped + redis servers. + hash_fn: Hash function to compute the shard key. Default is `crc16`. A `hash_fn` accepts a `str` or `bytes` as + input, and returns an `int` as output. 
""" def __init__( @@ -36,8 +37,8 @@ def __init__( hosts: List[Dict[str, str]] = None, cluster_mode: bool = False, capacity_per_node: int = 100_000_000_000, + hash_fn=None, ): - """ """ self.hosts = [] if hosts is None: @@ -61,7 +62,7 @@ def __init__( ) stores.append(store) - super(RedisStore, self).__init__(stores) + super(RedisStore, self).__init__(stores, hash_fn) def _bootstrap_redis_server(self): ip, port = get_host_ip(), find_free_port() From b53a18653844714ca355237f7a90de56b321ce4c Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 14:09:54 +0800 Subject: [PATCH 18/63] . --- bagua/torch_api/contrib/utils/redis_store.py | 2 +- bagua/torch_api/contrib/utils/store.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index d90e58bfc..f98caff2d 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -124,7 +124,7 @@ def get(self, key: str) -> Optional[str]: def num_keys(self) -> int: return self.client.dbsize() - def clear(self) -> bool: + def clear(self): self.client.flushdb() def mset(self, mapping: Dict[str, str]): diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 3b6c3d8ce..00117c092 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -21,7 +21,7 @@ def get(self, key: str) -> Optional[str]: def num_keys(self) -> int: pass - def clear(self) -> bool: + def clear(self): pass def mset(self, mapping: Dict[str, str]): @@ -84,7 +84,7 @@ def get(self, key: str) -> Optional[str]: def num_keys(self) -> int: return sum([store.num_keys() for store in self.stores]) - def clear(self) -> bool: + def clear(self): for store in self.stores: store.clear() From bb4acad4dc0570972a57cd4501c16ac032c4a1dc Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 15:22:25 +0800 Subject: [PATCH 19/63] refine docs 
--- bagua/torch_api/contrib/cache_dataset.py | 14 ++++++++------ bagua/torch_api/contrib/cache_loader.py | 11 ++++++++--- bagua/torch_api/contrib/utils/hash_func.py | 2 -- bagua/torch_api/contrib/utils/redis_store.py | 11 ++++++----- bagua/torch_api/contrib/utils/store.py | 15 +++++++++++++++ 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index a5772848d..b75dfa6cf 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -16,6 +16,8 @@ def __init__( """ A dataset wrapper which caches `dataset` samples. + This is useful in scenarios when `dataset` has a lot preprocessing work to fetch a sample. + Args: dataset: Dataset used for caching. backend(str): The backend to use. Currently "redis" is supported. @@ -33,16 +35,16 @@ def __init__( .. note:: This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. + + .. note:: + The cache assocaite dataset indices to determined dataset samples, thus it will violate the randomness of the dataset. + Use :class:`CacheLoader` which can wrap arbitrary data loading logic in this situation. + """ self.dataset = dataset - self.cache_loader = CacheLoader( - backend, - key_prefix, - batch_writes, - **kwargs, - ) + self.cache_loader = CacheLoader(backend, key_prefix, batch_writes, **kwargs,) def __getitem__(self, item): return self.cache_loader.get(item, lambda x: self.dataset[x]) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 2dabe02d6..bcfdefae1 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -25,18 +25,23 @@ def __init__( are stored in the cache until evicted. Args: - backend(str): The backend to use. Currently "redis" is supported. + backend(str): The backend to use. Currently "redis" is supported. 
If using "redis" backend, must provide + argument `hosts` to initialize :class:`RedisStore`. See :class:`RedisStore` for more information. key_prefix(str): Prefix of the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``1``. + batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the cache + will combine multiple `set` operations to one or a few `mset` operations. May help to reduce the write latency. Example:: - To use an initialized redis clusters: {'192.168.1.0:7000', '192.168.1.1:7000'} + To use "redis" backend and initialized redis clusters: `{'192.168.1.0:7000', '192.168.1.1:7000'}` >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}] >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True) >>> >>> loader.get(index, lambda x: items[x]) + To use "redis" backend and bootstrap redis servers locally: + + >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000) """ self.backend = backend diff --git a/bagua/torch_api/contrib/utils/hash_func.py b/bagua/torch_api/contrib/utils/hash_func.py index 77b10aeb4..88f1119b8 100644 --- a/bagua/torch_api/contrib/utils/hash_func.py +++ b/bagua/torch_api/contrib/utils/hash_func.py @@ -1,5 +1,3 @@ -__all__ = [] - # reference: https://github.com/lammertb/libcrc/blob/master/src/crc16.c # fmt: off diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index f98caff2d..6930aa24f 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -22,14 +22,15 @@ class RedisStore(ClusterStore): actions such as `set()` to insert a key-value pair, `get()` to retrieve a key-value pair, etc. Args: - hosts (List[Dict[str, str]]): A list of redis servers. Can be ``None``, which means to bootstrap redis servers - locally. 
+ hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of dict containing server host and + port information. Can be ``None``, which means to bootstrap redis servers locally. cluster_mode (bool): Redis servers serve as a cluster or not. If True, data is automatically sharded across all - redis servers. - capacity_per_node (int): Maximum memory limit in bytes to configure redis servers. Useful only for local bootstrapped - redis servers. + redis servers, otherwise, data is routed to a specific server. + capacity_per_node (int): Maximum memory limit in bytes to configure redis servers when bootstrap locally. Redis servers + will evict keys randomly when maximum memory limit is reached. hash_fn: Hash function to compute the shard key. Default is `crc16`. A `hash_fn` accepts a `str` or `bytes` as input, and returns an `int` as output. + """ def __init__( diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 00117c092..36324784b 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -13,27 +13,37 @@ class Store: """ def set(self, key: str, value: str): + "Set a key-value pair." pass def get(self, key: str) -> Optional[str]: + "Return the value associated with key `key`, or None if the key doesn’t exist" pass def num_keys(self) -> int: + "Returns the number of keys in the current store." pass def clear(self): + "Delete all keys in the current store." pass def mset(self, mapping: Dict[str, str]): + "Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings." pass def mget(self, keys: List[str]) -> List[Optional[str]]: + "Returns a list of values ordered identically to `keys`." pass def status(self) -> bool: + "Check the status of the current store." pass def shutdown(self): + """ + Shutdown the current store. 
External store resources, for example, initialized redis servers, will not be shutted down by this method. + """ pass @@ -47,6 +57,11 @@ class ClusterStore(Store): key-value pairs are manually added to the cluster using `set()` or `mset()` and can be retrieved by `get()` or `mget()`. + Args: + stores(List[Store]): A list of stores in the cluster. + hash_fn: Hash function to compute the shard key. Default is `crc16`. A `hash_fn` accepts a `str` or `bytes` as + input, and returns an `int` as output. + """ def __init__(self, stores: List[Store], hash_fn=None): From 88118889e956693eeec5f5e7b98a74031048dc0b Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 15:34:01 +0800 Subject: [PATCH 20/63] add --- bagua/torch_api/contrib/utils/store.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 36324784b..d7542248f 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -16,11 +16,11 @@ def set(self, key: str, value: str): "Set a key-value pair." pass - def get(self, key: str) -> Optional[str]: + def get(self, key): "Return the value associated with key `key`, or None if the key doesn’t exist" pass - def num_keys(self) -> int: + def num_keys(self): "Returns the number of keys in the current store." pass @@ -28,15 +28,15 @@ def clear(self): "Delete all keys in the current store." pass - def mset(self, mapping: Dict[str, str]): + def mset(self, mapping): "Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings." pass - def mget(self, keys: List[str]) -> List[Optional[str]]: + def mget(self, keys): "Returns a list of values ordered identically to `keys`." pass - def status(self) -> bool: + def status(self): "Check the status of the current store." 
pass From d90c0983d66a532153dc79cc772e6310ce9377ec Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 17:03:37 +0800 Subject: [PATCH 21/63] update doc --- bagua/torch_api/contrib/cache_dataset.py | 10 +++++++--- bagua/torch_api/contrib/cache_loader.py | 7 ++++--- docs/conf.py | 13 +++++++++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index b75dfa6cf..2323e6752 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -20,9 +20,12 @@ def __init__( Args: dataset: Dataset used for caching. - backend(str): The backend to use. Currently "redis" is supported. + backend(str): The backend to use. Currently ``"redis"`` is supported. If using ``"redis"`` backend, must provide + argument `hosts` to initialize :class:`RedisStore`. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` + for further customization. key_prefix(str): Prefix of the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``20``. + batch_writes(int): How many key-value pairs written to cache once. Default ``20``. If `batch_writes > 1`, the cache + will combine multiple `set` operations to one or a few `mset` operations. May help to reduce the write latency. Example:: @@ -34,7 +37,8 @@ def __init__( .. note:: - This class use :class:`CacheLoader` as the implementation of cache. See :class:`CacheLoader` for more information. + This class use :class:`CacheLoader` as the implementation of cache. See + :class:`bagua.torch_api.contrib.CacheLoader` for more information. .. note:: The cache assocaite dataset indices to determined dataset samples, thus it will violate the randomness of the dataset. 
diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index bcfdefae1..d97a839b6 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -25,14 +25,15 @@ def __init__( are stored in the cache until evicted. Args: - backend(str): The backend to use. Currently "redis" is supported. If using "redis" backend, must provide - argument `hosts` to initialize :class:`RedisStore`. See :class:`RedisStore` for more information. + backend(str): The backend to use. Currently ``"redis"`` is supported. If using ``"redis"`` backend, must provide + argument `hosts` to initialize :class:`RedisStore`. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` + for further customization. key_prefix(str): Prefix of the cache key. Default ``""``. batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the cache will combine multiple `set` operations to one or a few `mset` operations. May help to reduce the write latency. 
Example:: - To use "redis" backend and initialized redis clusters: `{'192.168.1.0:7000', '192.168.1.1:7000'}` + To use "redis" backend and initialized redis clusters: `{'192.168.1.0:7000', '192.168.1.1:7000'}`: >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}] >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True) diff --git a/docs/conf.py b/docs/conf.py index d72933dd5..cd545f56d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ # -- Project information ----------------------------------------------------- - +import re project = "Bagua API Documentation" copyright = "2021, Kuaishou AI Platform and DS3 Lab" author = "Kuaishou AI Platform and DS3 Lab" @@ -55,6 +55,7 @@ "*/bagua/torch_api/globals.py", "*/bagua/version.py", "*/bagua/bagua_define.py", + "*/bagua/torch_api/contrib/utils/hash_func.py" ] autoapi_options = [ "members", @@ -127,6 +128,7 @@ "bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler.generate_batches", "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedSampler.shuffle_chunks", "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedBatchSampler.generate_batches", + "bagua.torch_api.contrib.utils.store.ClusterStore.*" ] _ignore_functions = [ "bagua.torch_api.env.get_autotune_server_addr", @@ -160,9 +162,12 @@ def skip_methods(app, what, name, obj, skip, options): - if what == "method" and name in _ignore_methods: - skip = True - return skip + if what == "method": + for to_ignore in _ignore_methods: + p = re.compile(to_ignore) + if p.match(name): + skip = True + return skip if what == "function" and name in _ignore_functions: skip = True From ebdd5200beb43913e269428754ef9ffcddfb22e5 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 17:34:27 +0800 Subject: [PATCH 22/63] add import error --- bagua/torch_api/contrib/utils/redis_store.py | 23 +++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff 
--git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 6930aa24f..c7fb2b7ed 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -2,13 +2,34 @@ import subprocess import time from bagua.torch_api.env import get_rank, get_local_rank, get_world_size, get_local_size -from redis import Redis + +try: + from redis import Redis +except ImportError as err: + print( + "DEBUG: did not find redis-py. To install it, run `pip install redis` or follow instructions on its website(https://github.com/andymccurdy/redis-py)." + ) + raise err + from typing import List, Dict, Optional from .store import Store, ClusterStore import torch.distributed.distributed_c10d as c10d import json import logging +try: + p = subprocess.Popen( + ["redis-server" "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + out, err = p.communicate() +except Exception: + print( + "DEBUG: did not find redis-server. Follow instructions on its website(https://redis.io/download) to have it installed." 
+ ) + print("DEBUG: out: " + out) + print("DEBUG: err: " + err) + + __all__ = ["RedisStore"] _host_ip = None From 505f55b72c9cd863ddc75fc65dc0ee44b75ea68d Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Thu, 12 Aug 2021 17:36:24 +0800 Subject: [PATCH 23/63] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/conf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index cd545f56d..7d5d102b9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,6 +17,7 @@ # -- Project information ----------------------------------------------------- import re + project = "Bagua API Documentation" copyright = "2021, Kuaishou AI Platform and DS3 Lab" author = "Kuaishou AI Platform and DS3 Lab" @@ -55,7 +56,7 @@ "*/bagua/torch_api/globals.py", "*/bagua/version.py", "*/bagua/bagua_define.py", - "*/bagua/torch_api/contrib/utils/hash_func.py" + "*/bagua/torch_api/contrib/utils/hash_func.py", ] autoapi_options = [ "members", @@ -128,7 +129,7 @@ "bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler.generate_batches", "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedSampler.shuffle_chunks", "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedBatchSampler.generate_batches", - "bagua.torch_api.contrib.utils.store.ClusterStore.*" + "bagua.torch_api.contrib.utils.store.ClusterStore.*", ] _ignore_functions = [ "bagua.torch_api.env.get_autotune_server_addr", From 50c4ea727f503e8854b352405a4c07a2654e9793 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 18:12:55 +0800 Subject: [PATCH 24/63] add package check --- bagua/torch_api/contrib/utils/redis_store.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index c7fb2b7ed..aef6ce2ae 
100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -5,11 +5,11 @@ try: from redis import Redis -except ImportError as err: +except ImportError: print( "DEBUG: did not find redis-py. To install it, run `pip install redis` or follow instructions on its website(https://github.com/andymccurdy/redis-py)." ) - raise err + raise from typing import List, Dict, Optional from .store import Store, ClusterStore @@ -17,17 +17,16 @@ import json import logging + try: p = subprocess.Popen( - ["redis-server" "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ["redis-server", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - out, err = p.communicate() except Exception: print( "DEBUG: did not find redis-server. Follow instructions on its website(https://redis.io/download) to have it installed." ) - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) + raise __all__ = ["RedisStore"] From 401e3d4cdc18b56878f91c60bf3318fa71af4872 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 19:00:05 +0800 Subject: [PATCH 25/63] update --- bagua/torch_api/contrib/cache_dataset.py | 29 ++++++++++++++---------- bagua/torch_api/contrib/cache_loader.py | 14 ++++++++---- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py index 2323e6752..2f6a21e9d 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -16,16 +16,16 @@ def __init__( """ A dataset wrapper which caches `dataset` samples. - This is useful in scenarios when `dataset` has a lot preprocessing work to fetch a sample. + This is useful in scenarios when `dataset` has a lot pre-processing work to fetch a sample. Args: dataset: Dataset used for caching. - backend(str): The backend to use. Currently ``"redis"`` is supported. 
If using ``"redis"`` backend, must provide - argument `hosts` to initialize :class:`RedisStore`. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` - for further customization. + backend(str): The backend to use. Currently ``"redis"`` is supported. key_prefix(str): Prefix of the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``20``. If `batch_writes > 1`, the cache - will combine multiple `set` operations to one or a few `mset` operations. May help to reduce the write latency. + batch_writes(int): How many key-value pairs written to cache once. Default ``20``, If `batch_writes > 1`, the + cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. + Thus it could combine multiple `set` operations to one `mset` operation. This is expected to reduce + the write latency. Example:: @@ -36,19 +36,24 @@ def __init__( >>> dataloader = torch.utils.data.DataLoader(cached_dataset) .. note:: - - This class use :class:`CacheLoader` as the implementation of cache. See - :class:`bagua.torch_api.contrib.CacheLoader` for more information. + `CacheDataset` is a special use case of `CacheLoader`, and parameter `backend`, `key_prefix` and `batch_writes` + in `CacheDataset` have the same meanings in `CacheLoader`. See :class:`bagua.torch_api.contrib.CacheLoader` + for more information. .. note:: - The cache assocaite dataset indices to determined dataset samples, thus it will violate the randomness of the dataset. - Use :class:`CacheLoader` which can wrap arbitrary data loading logic in this situation. + The cache associates dataset indices to determined dataset samples, thus it will break the randomness of + the dataset, if it has. Use :class:`CacheLoader` which can wrap arbitrary data loading logic in this situation. 
""" self.dataset = dataset - self.cache_loader = CacheLoader(backend, key_prefix, batch_writes, **kwargs,) + self.cache_loader = CacheLoader( + backend, + key_prefix, + batch_writes, + **kwargs, + ) def __getitem__(self, item): return self.cache_loader.get(item, lambda x: self.dataset[x]) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index d97a839b6..0bcb1c2fd 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -24,13 +24,17 @@ def __init__( A mapping from keys to values. Values are automatically loaded by the cache, and are stored in the cache until evicted. + Current backend is "redis". Using "redis" backend, the cache will initialize an instance of :class:`RedisStore` + by a list of initialized redis servers or bootstrap redis servers locally. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` + for more information. + Args: - backend(str): The backend to use. Currently ``"redis"`` is supported. If using ``"redis"`` backend, must provide - argument `hosts` to initialize :class:`RedisStore`. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` - for further customization. + backend(str): The backend to use. Currently ``"redis"`` is supported. key_prefix(str): Prefix of the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the cache - will combine multiple `set` operations to one or a few `mset` operations. May help to reduce the write latency. + batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the + cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. + Thus it could combine multiple `set` operations to one `mset` operation. This is expected to reduce + the write latency. 
Example:: To use "redis" backend and initialized redis clusters: `{'192.168.1.0:7000', '192.168.1.1:7000'}`: From 471bb570fc4106515791816336f8002c594e4769 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 19:05:05 +0800 Subject: [PATCH 26/63] update test --- tests/contrib/test_cached_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 74ab263df..83df1f53a 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -7,7 +7,7 @@ logging.basicConfig(level=logging.DEBUG) -class TestDataset(Dataset): +class MyDataset(Dataset): def __init__(self, size): self.size = size self.dataset = [(np.random.rand(5, 2), np.random.rand(1)) for _ in range(size)] @@ -31,7 +31,7 @@ def check_dataset(self, dataset, cache_dataset): self.assertTrue((dataset[i][1] == cache_dataset[i][1]).all()) def test_redis(self): - dataset = TestDataset(102) + dataset = MyDataset(102) cache_dataset = CacheDataset( dataset, backend="redis", hosts=None, cluster_mode=False ) From a95c2609540785673f4af8bcf6168978c5e67559 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 19:20:49 +0800 Subject: [PATCH 27/63] fix type --- bagua/torch_api/contrib/utils/store.py | 35 ++++++++++++++++---------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index d7542248f..9a53e4144 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -12,37 +12,46 @@ class Store: `get()` or `mget()`. """ - def set(self, key: str, value: str): - "Set a key-value pair." 
+ def set(self, key, value): + """Set a key-value pair.""" pass - def get(self, key): - "Return the value associated with key `key`, or None if the key doesn’t exist" + def get(self, key) -> Optional[str]: + """Returns the value associated with key `key`, or None if the key doesn't exist.""" pass - def num_keys(self): - "Returns the number of keys in the current store." + def num_keys(self) -> int: + """Returns the number of keys in the current store.""" pass def clear(self): - "Delete all keys in the current store." + """Delete all keys in the current store.""" pass def mset(self, mapping): - "Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings." + """ + Set key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values + should be strings. + """ pass - def mget(self, keys): - "Returns a list of values ordered identically to `keys`." + def mget(self, keys) -> List[Optional[str]]: + """ + Returns a list of values ordered identically to `keys`. + """ + pass - def status(self): - "Check the status of the current store." + def status(self) -> bool: + """ + Returns the status of the current store. + """ pass def shutdown(self): """ - Shutdown the current store. External store resources, for example, initialized redis servers, will not be shutted down by this method. + Shutdown the current store. External store resources, for example, initialized redis servers, + will not be shutted down by this method. 
""" pass From 1c2d4c926397eef463801d491538a2b25dbb29cb Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 19:43:17 +0800 Subject: [PATCH 28/63] add --- bagua/torch_api/contrib/cache_loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 0bcb1c2fd..d4be4834e 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -25,12 +25,12 @@ def __init__( are stored in the cache until evicted. Current backend is "redis". Using "redis" backend, the cache will initialize an instance of :class:`RedisStore` - by a list of initialized redis servers or bootstrap redis servers locally. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` + by a list of initialized redis servers or bootstrapping redis servers locally. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` for more information. Args: backend(str): The backend to use. Currently ``"redis"`` is supported. - key_prefix(str): Prefix of the cache key. Default ``""``. + key_prefix(str): Prefix added to the cache key. Default ``""``. batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. Thus it could combine multiple `set` operations to one `mset` operation. This is expected to reduce @@ -47,6 +47,9 @@ def __init__( To use "redis" backend and bootstrap redis servers locally: >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000) + + .. note:: + Setting a specific `key_prefix` can be useful to avoid overwriting existing cache data. """ self.backend = backend From 179d939fcda3e8303f1b743acbba5ab7993e2c22 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 19:46:01 +0800 Subject: [PATCH 29/63] . 
--- bagua/torch_api/contrib/utils/store.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 9a53e4144..f492aff63 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -16,11 +16,11 @@ def set(self, key, value): """Set a key-value pair.""" pass - def get(self, key) -> Optional[str]: + def get(self, key) -> Optional[str]: # type: ignore """Returns the value associated with key `key`, or None if the key doesn't exist.""" pass - def num_keys(self) -> int: + def num_keys(self) -> int: # type: ignore """Returns the number of keys in the current store.""" pass @@ -35,14 +35,14 @@ def mset(self, mapping): """ pass - def mget(self, keys) -> List[Optional[str]]: + def mget(self, keys) -> List[Optional[str]]: # type: ignore """ Returns a list of values ordered identically to `keys`. """ pass - def status(self) -> bool: + def status(self) -> bool: # type: ignore """ Returns the status of the current store. 
""" From 5ddef17bbcf10b17b586c51fc31dee040df30093 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 21:21:25 +0800 Subject: [PATCH 30/63] pytype --- bagua/torch_api/contrib/utils/store.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index f492aff63..9233f80b5 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -16,13 +16,13 @@ def set(self, key, value): """Set a key-value pair.""" pass - def get(self, key) -> Optional[str]: # type: ignore + def get(self, key) -> Optional[str]: """Returns the value associated with key `key`, or None if the key doesn't exist.""" - pass + pass # type: ignore - def num_keys(self) -> int: # type: ignore + def num_keys(self) -> int: """Returns the number of keys in the current store.""" - pass + pass # type: ignore def clear(self): """Delete all keys in the current store.""" @@ -35,18 +35,18 @@ def mset(self, mapping): """ pass - def mget(self, keys) -> List[Optional[str]]: # type: ignore + def mget(self, keys) -> List[Optional[str]]: """ Returns a list of values ordered identically to `keys`. """ - pass + pass # type: ignore - def status(self) -> bool: # type: ignore + def status(self) -> bool: """ Returns the status of the current store. """ - pass + pass # type: ignore def shutdown(self): """ From 62a57531831caaf7c8fd0305e2b5733afe1c3734 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 21:45:57 +0800 Subject: [PATCH 31/63] . --- bagua/torch_api/contrib/cache_loader.py | 5 +++-- bagua/torch_api/contrib/utils/redis_store.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index d4be4834e..7bd7eed08 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -37,10 +37,11 @@ def __init__( the write latency. 
Example:: - To use "redis" backend and initialized redis clusters: `{'192.168.1.0:7000', '192.168.1.1:7000'}`: + To use "redis" backend and initialized redis clusters: + >>> from bagua.torch_api.contrib import CacheLoader >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}] - >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True) + >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, key_prefix="test") >>> >>> loader.get(index, lambda x: items[x]) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index aef6ce2ae..b124ca0db 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -55,7 +55,7 @@ class RedisStore(ClusterStore): def __init__( self, - hosts: List[Dict[str, str]] = None, + hosts: Optional[List[Dict[str, str]]] = None, cluster_mode: bool = False, capacity_per_node: int = 100_000_000_000, hash_fn=None, @@ -159,7 +159,7 @@ def status(self) -> bool: def shutdown(self): if self.bootstrap: - self.client.shutdown(nosave=True) + self.client.shutdown(nosave=True) # pytype: disable=wrong-keyword-args def create_redis_client(host, port): From 06ba1f57b6447ea1ae7ec6f4027343ec122491ea Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 22:04:12 +0800 Subject: [PATCH 32/63] auto cleanup --- bagua/torch_api/contrib/cache_loader.py | 9 ++++----- docs/conf.py | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 7bd7eed08..bd8349536 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -1,5 +1,6 @@ import pickle from collections import defaultdict +import atexit __all__ = ["CacheLoader"] @@ -64,6 +65,7 @@ def __init__( raise ValueError('invalid backend, only support "redis" currently') self.fetcher = BatchFetcher(self.store, 1, 
batch_writes) + self.register_shutdown_handler() def get(self, key, load_fn): """ @@ -85,11 +87,8 @@ def num_keys(self): return self.store.num_keys() - def cleanup(self): - """Cleanup the resources used.""" - - # TODO: cleanup automatically - self.store.shutdown() + def register_shutdown_handler(self): + atexit.register(self.store.shutdown) class BatchFetcher: diff --git a/docs/conf.py b/docs/conf.py index 7d5d102b9..99b506138 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -130,6 +130,7 @@ "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedSampler.shuffle_chunks", "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedBatchSampler.generate_batches", "bagua.torch_api.contrib.utils.store.ClusterStore.*", + "bagua.torch_api.contrib.cache_loader.CacheLoader.register_shutdown_handler", ] _ignore_functions = [ "bagua.torch_api.env.get_autotune_server_addr", From 75c9f2c70f7d8aa788ef3029a8969acc05a550fb Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 22:23:30 +0800 Subject: [PATCH 33/63] pytype --- bagua/torch_api/contrib/utils/hash_func.py | 5 ++++- bagua/torch_api/contrib/utils/redis_store.py | 10 +++++----- bagua/torch_api/contrib/utils/store.py | 14 +++++++------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/bagua/torch_api/contrib/utils/hash_func.py b/bagua/torch_api/contrib/utils/hash_func.py index 88f1119b8..188a454c5 100644 --- a/bagua/torch_api/contrib/utils/hash_func.py +++ b/bagua/torch_api/contrib/utils/hash_func.py @@ -1,5 +1,8 @@ # reference: https://github.com/lammertb/libcrc/blob/master/src/crc16.c +from typing import Union + + # fmt: off table = [ 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, @@ -37,7 +40,7 @@ ] -def crc16(data): +def crc16(data: Union[str, bytes]): if isinstance(data, str): data = data.encode() diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index b124ca0db..307546462 100644 --- 
a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -11,7 +11,7 @@ ) raise -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from .store import Store, ClusterStore import torch.distributed.distributed_c10d as c10d import json @@ -136,10 +136,10 @@ def _connect_with_retry(self, retry_times=3): return False - def set(self, key: str, value: str): + def set(self, key: str, value: Union[str, bytes]): self.client.set(key, value) - def get(self, key: str) -> Optional[str]: + def get(self, key: str) -> Optional[Union[str, bytes]]: return self.client.get(key) def num_keys(self) -> int: @@ -148,10 +148,10 @@ def num_keys(self) -> int: def clear(self): self.client.flushdb() - def mset(self, mapping: Dict[str, str]): + def mset(self, mapping: Dict[str, Union[str, bytes]]): self.client.mset(mapping) - def mget(self, keys: List[str]) -> List[Optional[str]]: + def mget(self, keys: List[str]) -> List[Optional[Union[str, bytes]]]: return self.client.mget(keys) def status(self) -> bool: diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 9233f80b5..49ff8c870 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from collections import defaultdict @@ -16,7 +16,7 @@ def set(self, key, value): """Set a key-value pair.""" pass - def get(self, key) -> Optional[str]: + def get(self, key) -> Optional[Union[str, bytes]]: """Returns the value associated with key `key`, or None if the key doesn't exist.""" pass # type: ignore @@ -35,7 +35,7 @@ def mset(self, mapping): """ pass - def mget(self, keys) -> List[Optional[str]]: + def mget(self, keys) -> List[Optional[Union[str, bytes]]]: """ Returns a list of values ordered identically to `keys`. 
""" @@ -93,13 +93,13 @@ def route(self, key) -> Store: self.stores[self._hash_key(key)] if self.num_stores > 1 else self.stores[0] ) - def set(self, key: str, value: str): + def set(self, key: str, value: Union[str, bytes]): if self.num_stores == 1: return self.stores[0].set(key, value) self.route(key).set(key, value) - def get(self, key: str) -> Optional[str]: + def get(self, key: str) -> Optional[Union[str, bytes]]: if self.num_stores == 1: return self.stores[0].get(key) @@ -112,7 +112,7 @@ def clear(self): for store in self.stores: store.clear() - def mset(self, mapping: Dict[str, str]): + def mset(self, mapping: Dict[str, Union[str, bytes]]): if self.num_stores == 1: return self.stores[0].mset(mapping) @@ -126,7 +126,7 @@ def mset(self, mapping: Dict[str, str]): for sid, m in route_table.items(): self.stores[sid].mset(m) - def mget(self, keys: List[str]) -> List[Optional[str]]: + def mget(self, keys: List[str]) -> List[Optional[Union[str, bytes]]]: if self.num_stores == 1: return self.stores[0].mget(keys) From eb8fde5bb15a3334043c63754f00312f1ef4798d Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 22:38:29 +0800 Subject: [PATCH 34/63] . 
--- bagua/torch_api/contrib/cache_loader.py | 2 +- docs/conf.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index bd8349536..0dea275f7 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -73,7 +73,7 @@ def get(self, key, load_fn): `load_fn` accepts `key` as input, and returns an object ser """ - cache_key = "{}{}".format(self.key_prefix, key).encode() + cache_key = "{}{}".format(self.key_prefix, key) ret = self.fetcher.read(cache_key) if ret is None: diff --git a/docs/conf.py b/docs/conf.py index 99b506138..7ecaf1319 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -125,12 +125,10 @@ _ignore_methods = [ - "bagua.torch_api.contrib.LoadBalancingDistributedSampler.shuffle_chunks", - "bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler.generate_batches", - "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedSampler.shuffle_chunks", - "bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedBatchSampler.generate_batches", + "bagua.torch_api.contrib.*LoadBalancingDistributedSampler.shuffle_chunks", + "bagua.torch_api.contrib.*LoadBalancingDistributedBatchSampler.generate_batches", "bagua.torch_api.contrib.utils.store.ClusterStore.*", - "bagua.torch_api.contrib.cache_loader.CacheLoader.register_shutdown_handler", + "bagua.torch_api.contrib.*CacheLoader.register_shutdown_handler", ] _ignore_functions = [ "bagua.torch_api.env.get_autotune_server_addr", From 45bf719eaf10d48f890907e7a322a44a2d0d72b2 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 12 Aug 2021 22:57:06 +0800 Subject: [PATCH 35/63] doc doc --- bagua/torch_api/contrib/cache_dataset.py | 6 +++--- bagua/torch_api/contrib/cache_loader.py | 4 ++-- bagua/torch_api/contrib/utils/store.py | 5 ++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/bagua/torch_api/contrib/cache_dataset.py 
b/bagua/torch_api/contrib/cache_dataset.py index 2f6a21e9d..6b2540795 100644 --- a/bagua/torch_api/contrib/cache_dataset.py +++ b/bagua/torch_api/contrib/cache_dataset.py @@ -16,7 +16,7 @@ def __init__( """ A dataset wrapper which caches `dataset` samples. - This is useful in scenarios when `dataset` has a lot pre-processing work to fetch a sample. + This is useful in scenarios when `dataset` is slow to fetch a sample. Args: dataset: Dataset used for caching. @@ -24,7 +24,7 @@ def __init__( key_prefix(str): Prefix of the cache key. Default ``""``. batch_writes(int): How many key-value pairs written to cache once. Default ``20``, If `batch_writes > 1`, the cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. - Thus it could combine multiple `set` operations to one `mset` operation. This is expected to reduce + Thus it could combine multiple `set` operations into one `mset` operation, and is expected to reduce the write latency. Example:: @@ -37,7 +37,7 @@ def __init__( .. note:: `CacheDataset` is a special use case of `CacheLoader`, and parameter `backend`, `key_prefix` and `batch_writes` - in `CacheDataset` have the same meanings in `CacheLoader`. See :class:`bagua.torch_api.contrib.CacheLoader` + in `CacheDataset` have the same meanings with those in `CacheLoader`. See :class:`bagua.torch_api.contrib.CacheLoader` for more information. .. note:: diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 0dea275f7..44eccbe7d 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -34,7 +34,7 @@ def __init__( key_prefix(str): Prefix added to the cache key. Default ``""``. batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. 
- Thus it could combine multiple `set` operations to one `mset` operation. This is expected to reduce + Thus it could combine multiple `set` operations into one `mset` operation, and is expected to reduce the write latency. Example:: @@ -70,7 +70,7 @@ def __init__( def get(self, key, load_fn): """ Returns the value associated with key in cache, first loading the value if necessary. - `load_fn` accepts `key` as input, and returns an object ser + `load_fn` accepts `key` as input, and returns the data to be serialized and stored. """ cache_key = "{}{}".format(self.key_prefix, key) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 49ff8c870..18660c37b 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -30,8 +30,7 @@ def clear(self): def mset(self, mapping): """ - Set key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values - should be strings. + Set key/values based on a mapping. Mapping is a dictionary of key/value pairs. """ pass @@ -51,7 +50,7 @@ def status(self) -> bool: def shutdown(self): """ Shutdown the current store. External store resources, for example, initialized redis servers, - will not be shutted down by this method. + will not be shut down by this method. 
""" pass From 89aed0a559dd1514b016a56c859b4a5ead66b6b2 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 10:45:15 +0800 Subject: [PATCH 36/63] update hash --- bagua/torch_api/contrib/utils/hash_func.py | 52 -------------------- bagua/torch_api/contrib/utils/redis_store.py | 2 +- bagua/torch_api/contrib/utils/store.py | 8 +-- docs/conf.py | 1 - setup.py | 15 ++---- 5 files changed, 9 insertions(+), 69 deletions(-) delete mode 100644 bagua/torch_api/contrib/utils/hash_func.py diff --git a/bagua/torch_api/contrib/utils/hash_func.py b/bagua/torch_api/contrib/utils/hash_func.py deleted file mode 100644 index 188a454c5..000000000 --- a/bagua/torch_api/contrib/utils/hash_func.py +++ /dev/null @@ -1,52 +0,0 @@ -# reference: https://github.com/lammertb/libcrc/blob/master/src/crc16.c - -from typing import Union - - -# fmt: off -table = [ - 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, - 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, - 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, - 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, - 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, - 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, - 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, - 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, - 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, - 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, - 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, - 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, - 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, - 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, - 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, - 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, - 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, - 0x1080, 0x00a1, 0x30c2, 
0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, - 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, - 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, - 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, - 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, - 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, - 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, - 0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, - 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, - 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, - 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, - 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, - 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, - 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, - 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0 -] - - -def crc16(data: Union[str, bytes]): - if isinstance(data, str): - data = data.encode() - - hash_code = 0x000 - - for i in data: - hash_code = (hash_code >> 8) ^ table[(hash_code ^ i) & 0xFF] - - return hash_code diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 307546462..2ed2eaa02 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -48,7 +48,7 @@ class RedisStore(ClusterStore): redis servers, otherwise, data is routed to a specific server. capacity_per_node (int): Maximum memory limit in bytes to configure redis servers when bootstrap locally. Redis servers will evict keys randomly when maximum memory limit is reached. - hash_fn: Hash function to compute the shard key. Default is `crc16`. A `hash_fn` accepts a `str` or `bytes` as + hash_fn: Hash function to compute the shard key. Default is `xxh64`. A `hash_fn` accepts a `str` as input, and returns an `int` as output. 
""" diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 18660c37b..563044110 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -59,7 +59,7 @@ class ClusterStore(Store): """ An implementation for a store cluster. - This class implements client side sharding. It uses CRC-16 algorithm to compute the shard key by default, and can + This class implements client side sharding. It uses xxHash algorithm to compute the shard key by default, and can accept customized hashing algorithms by passing `hash_fn` on initialization. key-value pairs are manually added to the cluster using `set()` or `mset()` and can be retrieved by @@ -67,7 +67,7 @@ class ClusterStore(Store): Args: stores(List[Store]): A list of stores in the cluster. - hash_fn: Hash function to compute the shard key. Default is `crc16`. A `hash_fn` accepts a `str` or `bytes` as + hash_fn: Hash function to compute the shard key. Default is `xxh64`. A `hash_fn` accepts a `str` as input, and returns an `int` as output. 
""" @@ -78,9 +78,9 @@ def __init__(self, stores: List[Store], hash_fn=None): self.num_stores = len(stores) if hash_fn is None: - from .hash_func import crc16 + import xxhash - hash_fn = crc16 + hash_fn = lambda x: xxhash.xxh64(x).intdigest() self.hash_fn = hash_fn def _hash_key(self, key) -> int: diff --git a/docs/conf.py b/docs/conf.py index 7ecaf1319..83440e5a4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,7 +56,6 @@ "*/bagua/torch_api/globals.py", "*/bagua/version.py", "*/bagua/bagua_define.py", - "*/bagua/torch_api/contrib/utils/hash_func.py", ] autoapi_options = [ "members", diff --git a/setup.py b/setup.py index 337f397e2..b041810d2 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,5 @@ import os -from distutils.errors import ( - DistutilsPlatformError, -) +from distutils.errors import DistutilsPlatformError from setuptools import setup, find_packages import sys @@ -47,13 +45,8 @@ def check_torch_version(): "requests", "gorilla", "gevent", + "xxhash==v2.0.2", ], - entry_points={ - "console_scripts": [ - "baguarun = bagua.script.baguarun:main", - ], - }, - scripts=[ - "bagua/script/bagua_sys_perf", - ], + entry_points={"console_scripts": ["baguarun = bagua.script.baguarun:main",],}, + scripts=["bagua/script/bagua_sys_perf",], ) From 351bf23ab58af6d66085e2949c91d0272051566a Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Fri, 13 Aug 2021 10:59:12 +0800 Subject: [PATCH 37/63] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- setup.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index b041810d2..f0fa6e626 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,12 @@ def check_torch_version(): "gevent", "xxhash==v2.0.2", ], - entry_points={"console_scripts": ["baguarun = bagua.script.baguarun:main",],}, - scripts=["bagua/script/bagua_sys_perf",], + entry_points={ + "console_scripts": [ + 
"baguarun = bagua.script.baguarun:main", + ], + }, + scripts=[ + "bagua/script/bagua_sys_perf", + ], ) From 51b80eee10f2b5b009bb6661a94abac63204230d Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 10:59:47 +0800 Subject: [PATCH 38/63] style --- bagua/torch_api/contrib/utils/store.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 563044110..57e55bc83 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -80,7 +80,11 @@ def __init__(self, stores: List[Store], hash_fn=None): if hash_fn is None: import xxhash - hash_fn = lambda x: xxhash.xxh64(x).intdigest() + def xxh64(x): + return xxhash.xxh64(x).intdigest() + + hash_fn = xxh64 + self.hash_fn = hash_fn def _hash_key(self, key) -> int: From 7a6f357e8833fac09ca3297b7d62c89e896a482f Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 14:29:59 +0800 Subject: [PATCH 39/63] update after review --- bagua/torch_api/contrib/__init__.py | 2 +- bagua/torch_api/contrib/cache_dataset.py | 62 -------------------- bagua/torch_api/contrib/cache_loader.py | 24 ++++---- bagua/torch_api/contrib/cached_dataset.py | 58 ++++++++++++++++++ bagua/torch_api/contrib/utils/redis_store.py | 2 +- tests/contrib/test_cached_dataset.py | 4 +- 6 files changed, 73 insertions(+), 79 deletions(-) delete mode 100644 bagua/torch_api/contrib/cache_dataset.py create mode 100644 bagua/torch_api/contrib/cached_dataset.py diff --git a/bagua/torch_api/contrib/__init__.py b/bagua/torch_api/contrib/__init__.py index 0dfe649d6..40c4c9de9 100644 --- a/bagua/torch_api/contrib/__init__.py +++ b/bagua/torch_api/contrib/__init__.py @@ -4,4 +4,4 @@ LoadBalancingDistributedBatchSampler, ) from .cache_loader import CacheLoader # noqa: F401 -from .cache_dataset import CacheDataset # noqa: F401 +from .cached_dataset import CachedDataset # noqa: F401 diff --git 
a/bagua/torch_api/contrib/cache_dataset.py b/bagua/torch_api/contrib/cache_dataset.py deleted file mode 100644 index 6b2540795..000000000 --- a/bagua/torch_api/contrib/cache_dataset.py +++ /dev/null @@ -1,62 +0,0 @@ -from torch.utils.data.dataset import Dataset -from .cache_loader import CacheLoader - -__all__ = ["CacheDataset"] - - -class CacheDataset(Dataset): - def __init__( - self, - dataset: Dataset, - backend: str = "redis", - key_prefix: str = "", - batch_writes: int = 20, - **kwargs, - ): - """ - A dataset wrapper which caches `dataset` samples. - - This is useful in scenarios when `dataset` is slow to fetch a sample. - - Args: - dataset: Dataset used for caching. - backend(str): The backend to use. Currently ``"redis"`` is supported. - key_prefix(str): Prefix of the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``20``, If `batch_writes > 1`, the - cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. - Thus it could combine multiple `set` operations into one `mset` operation, and is expected to reduce - the write latency. - - Example:: - - >>> from bagua.torch_api.contrib import CacheDataset - >>> cache_dataset = CacheDataset( - ... dataset, backend="redis", hosts=None, cluster_mode=False - ... ) - >>> dataloader = torch.utils.data.DataLoader(cached_dataset) - - .. note:: - `CacheDataset` is a special use case of `CacheLoader`, and parameter `backend`, `key_prefix` and `batch_writes` - in `CacheDataset` have the same meanings with those in `CacheLoader`. See :class:`bagua.torch_api.contrib.CacheLoader` - for more information. - - .. note:: - The cache associates dataset indices to determined dataset samples, thus it will break the randomness of - the dataset, if it has. Use :class:`CacheLoader` which can wrap arbitrary data loading logic in this situation. 
- - """ - - self.dataset = dataset - - self.cache_loader = CacheLoader( - backend, - key_prefix, - batch_writes, - **kwargs, - ) - - def __getitem__(self, item): - return self.cache_loader.get(item, lambda x: self.dataset[x]) - - def __len__(self): - return len(self.dataset) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 44eccbe7d..fbf042837 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -18,7 +18,7 @@ def __init__( self, backend: str = "redis", key_prefix: str = "", - batch_writes: int = 1, + writer_buffer_size: int = 1, **kwargs, ): """ @@ -26,16 +26,14 @@ def __init__( are stored in the cache until evicted. Current backend is "redis". Using "redis" backend, the cache will initialize an instance of :class:`RedisStore` - by a list of initialized redis servers or bootstrapping redis servers locally. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` - for more information. + by a list of initialized redis servers or bootstrapping redis servers locally. See + :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` for more information. Args: - backend(str): The backend to use. Currently ``"redis"`` is supported. + backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. key_prefix(str): Prefix added to the cache key. Default ``""``. - batch_writes(int): How many key-value pairs written to cache once. Default ``1``. If `batch_writes > 1`, the - cache will delay writing non-existed key-value pairs until `batch_writes` key-value pairs are accumulated. - Thus it could combine multiple `set` operations into one `mset` operation, and is expected to reduce - the write latency. + writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. + Useful for improving the backend throughput. 
Example:: To use "redis" backend and initialized redis clusters: @@ -64,7 +62,7 @@ def __init__( else: raise ValueError('invalid backend, only support "redis" currently') - self.fetcher = BatchFetcher(self.store, 1, batch_writes) + self.fetcher = BatchFetcher(self.store, 1, writer_buffer_size) self.register_shutdown_handler() def get(self, key, load_fn): @@ -92,10 +90,10 @@ def register_shutdown_handler(self): class BatchFetcher: - def __init__(self, store, batch_reads, batch_writes): + def __init__(self, store, read_buffer_size, writer_buffer_size): self.store = store - self.batch_reads = max(1, batch_reads) - self.batch_writes = max(1, batch_writes) + self.read_buffer_size = max(1, read_buffer_size) + self.writer_buffer_size = max(1, writer_buffer_size) self.write_map = defaultdict() self.write_cnt = 0 @@ -121,7 +119,7 @@ def write(self, key, value): self.write_cnt += 1 self.write_map[key] = serialize(value) - if self.write_cnt % self.batch_writes == 0: + if self.write_cnt % self.writer_buffer_size == 0: self.flush_write_map() def write_post_read(self): diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py new file mode 100644 index 000000000..b714abbe2 --- /dev/null +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -0,0 +1,58 @@ +from torch.utils.data.dataset import Dataset +from .cache_loader import CacheLoader + +__all__ = ["CachedDataset"] + + +class CachedDataset(Dataset): + def __init__( + self, + dataset: Dataset, + backend: str = "redis", + dataset_name: str = "", + writer_buffer_size: int = 20, + **kwargs, + ): + """ + CachedDataset wraps a PyTorch Dataset to cache its samples in memory, so that accessing these samples after the + first time can be much faster. This is useful when samples need tedious preprocessing to produce, or reading + the dataset itself is slow, which could slow down the whole training process. 
+ + Internally, the samples are indexed by a key ``"{dataset_name}_{index}"`` and saved in a distributed Key-Value + store, where ``dataset_name`` is specified when initializing the CachedDataset, and ``index`` is the index + of a specific sample (the argument of `__getitem__(...)` method in a PyTorch Dataset). + + Args: + dataset: PyTorch Dataset to be wrapped. + backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. + dataset_name(str): Name of the dataset. Default ``""``. + writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. + Useful for improving the backend throughput. + + Example:: + + >>> from bagua.torch_api.contrib import CachedDataset + >>> cache_dataset = CachedDataset(dataset, backend="redis") + >>> dataloader = torch.utils.data.DataLoader(cached_dataset) + + .. note:: + `CachedDataset` is a special case of `CacheLoader`, and parameter `backend`, `key_prefix` and `writer_buffer_size` + in `CachedDataset` have the same meanings with those in `CacheLoader`. See :class:`bagua.torch_api.contrib.CacheLoader` + for more information. 
+ + """ + + self.dataset = dataset + + self.cache_loader = CacheLoader( + backend, + dataset_name, + writer_buffer_size, + **kwargs, + ) + + def __getitem__(self, item): + return self.cache_loader.get(item, lambda x: self.dataset[x]) + + def __len__(self): + return len(self.dataset) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 2ed2eaa02..9f8f51b10 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -56,7 +56,7 @@ class RedisStore(ClusterStore): def __init__( self, hosts: Optional[List[Dict[str, str]]] = None, - cluster_mode: bool = False, + cluster_mode: bool = True, capacity_per_node: int = 100_000_000_000, hash_fn=None, ): diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 83df1f53a..6ca7fb42f 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -1,4 +1,4 @@ -from bagua.torch_api.contrib.cache_dataset import CacheDataset +from bagua.torch_api.contrib.cached_dataset import CachedDataset from torch.utils.data.dataset import Dataset import numpy as np import logging @@ -32,7 +32,7 @@ def check_dataset(self, dataset, cache_dataset): def test_redis(self): dataset = MyDataset(102) - cache_dataset = CacheDataset( + cache_dataset = CachedDataset( dataset, backend="redis", hosts=None, cluster_mode=False ) self.check_dataset(dataset, cache_dataset) From 489464236c51b1252fb48953a1593c72ffcf6329 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 15:06:27 +0800 Subject: [PATCH 40/63] doc --- bagua/torch_api/contrib/cache_loader.py | 15 ++++++++++----- bagua/torch_api/contrib/cached_dataset.py | 4 ++-- bagua/torch_api/contrib/utils/redis_store.py | 2 +- bagua/torch_api/contrib/utils/store.py | 6 +++--- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 
fbf042837..659146821 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -25,9 +25,12 @@ def __init__( A mapping from keys to values. Values are automatically loaded by the cache, and are stored in the cache until evicted. - Current backend is "redis". Using "redis" backend, the cache will initialize an instance of :class:`RedisStore` - by a list of initialized redis servers or bootstrapping redis servers locally. See - :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` for more information. + Internally, values are indexed by ``"{key_prefix}_{key}"`` and saved in a distributed Key-Value + store, where ``key_prefix`` is specified on initializing, and ``key`` is the argument in :func:`get`. + + By default, CacheLoader uses :class:`RedisStore` as its backend distributed Key-Value store implementation. It + could reuse a list of initialized redis servers or bootstrap local redis servers by itself. See + :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` for further customization. Args: backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. @@ -36,20 +39,22 @@ def __init__( Useful for improving the backend throughput. Example:: - To use "redis" backend and initialized redis clusters: + To reuse a list of initialized redis servers for "redis" backend: >>> from bagua.torch_api.contrib import CacheLoader + >>> >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}] >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, key_prefix="test") >>> >>> loader.get(index, lambda x: items[x]) - To use "redis" backend and bootstrap redis servers locally: + To bootstrap local redis servers for "redis" backend: >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000) .. note:: Setting a specific `key_prefix` can be useful to avoid overwriting existing cache data. 
+ """ self.backend = backend diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index b714abbe2..9750e8246 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -37,8 +37,8 @@ def __init__( .. note:: `CachedDataset` is a special case of `CacheLoader`, and parameter `backend`, `key_prefix` and `writer_buffer_size` - in `CachedDataset` have the same meanings with those in `CacheLoader`. See :class:`bagua.torch_api.contrib.CacheLoader` - for more information. + in `CachedDataset` have the same meanings with those in `CacheLoader`. Further customization can be found in + :class:`bagua.torch_api.contrib.CacheLoader`. """ diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 9f8f51b10..da0cd0426 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -36,7 +36,7 @@ class RedisStore(ClusterStore): """ - A Redis-based store implementation. + A Redis-based Key-Value store implementation. The server holds the data, while the client can connect to the server over Redis protocol and perform actions such as `set()` to insert a key-value pair, `get()` to retrieve a key-value pair, etc. diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 57e55bc83..6101f970f 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -7,7 +7,7 @@ class Store: """ - Base class for all store implementations. A store keeps a mapping from keys to values. + Base class for all Key-Value store implementations. A store keeps a mapping from keys to values. key-value pairs are manually added to store using `set()` or `mset()` and can be retrieved by `get()` or `mget()`. """ @@ -57,9 +57,9 @@ def shutdown(self): class ClusterStore(Store): """ - An implementation for a store cluster. 
+ Base class for distributed Key-Value stores. - This class implements client side sharding. It uses xxHash algorithm to compute the shard key by default, and can + This class implements client side sharding. It uses **xxHash** algorithm to compute the shard key by default, and can accept customized hashing algorithms by passing `hash_fn` on initialization. key-value pairs are manually added to the cluster using `set()` or `mset()` and can be retrieved by From b262801a8f6573328cbccf38f7abd9cc4b0c5e4f Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 15:09:43 +0800 Subject: [PATCH 41/63] d --- bagua/torch_api/contrib/cached_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 9750e8246..432fa186d 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -36,7 +36,7 @@ def __init__( >>> dataloader = torch.utils.data.DataLoader(cached_dataset) .. note:: - `CachedDataset` is a special case of `CacheLoader`, and parameter `backend`, `key_prefix` and `writer_buffer_size` + `CachedDataset` is a special case of `CacheLoader`, and parameter `backend` and `writer_buffer_size` in `CachedDataset` have the same meanings with those in `CacheLoader`. Further customization can be found in :class:`bagua.torch_api.contrib.CacheLoader`. 
From 15f2977a8130db29ed70f02d0ff18b7210f6f100 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 15:57:14 +0800 Subject: [PATCH 42/63] share redis server --- bagua/torch_api/contrib/utils/redis_store.py | 68 +++++++++++++------- tests/contrib/test_cached_dataset.py | 16 +++-- 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 9f8f51b10..8349eb235 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -32,11 +32,12 @@ __all__ = ["RedisStore"] _host_ip = None +_bootstrap_redis_hosts = [] class RedisStore(ClusterStore): """ - A Redis-based store implementation. + A Redis-based Key-Value store implementation. The server holds the data, while the client can connect to the server over Redis protocol and perform actions such as `set()` to insert a key-value pair, `get()` to retrieve a key-value pair, etc. @@ -51,6 +52,9 @@ class RedisStore(ClusterStore): hash_fn: Hash function to compute the shard key. Default is `xxh64`. A `hash_fn` accepts a `str` as input, and returns an `int` as output. + .. note:: + Only one redis server can be bootstrapped on each node, thus the maximum memory limit of it is determined on + its first initialization. 
""" def __init__( @@ -74,7 +78,8 @@ def __init__( self.capacity_per_node = capacity_per_node if self.bootstrap: - self._bootstrap_redis_server() + bootstrap_redis_server(self.capacity_per_node) + self.hosts.extend(get_bootstrapped_redis_server(self.cluster_mode)) stores = [] for h in self.hosts: @@ -85,32 +90,49 @@ def __init__( super(RedisStore, self).__init__(stores, hash_fn) - def _bootstrap_redis_server(self): - ip, port = get_host_ip(), find_free_port() - hostinfo = {"host": ip, "port": port} + +def _is_bootstrapped(): + global _bootstrap_redis_hosts + + return _bootstrap_redis_hosts is not None and len(_bootstrap_redis_hosts) > 0 + + +def bootstrap_redis_server(capacity_per_node): + if _is_bootstrapped(): + logging.debug("local redis server is already bootstrapped") + return + + ip, port = get_host_ip(), find_free_port() + hostinfo = {"host": ip, "port": port} + if get_local_rank() == 0: + start_redis_server_cli(port, capacity_per_node) + + hosts = [] + nrank = get_rank() // get_local_size() + if get_world_size() > 1: + nnodes = get_world_size() // get_local_size() + default_store = c10d._get_default_store() + key_pattern = "redis-node{}" + if get_local_rank() == 0: - start_redis_server_cli(port, self.capacity_per_node) + default_store.set(key_pattern.format(nrank), json.dumps(hostinfo)) - hosts = [] - nrank = get_rank() // get_local_size() - if get_world_size() > 1: - nnodes = get_world_size() // get_local_size() - default_store = c10d._get_default_store() - key_pattern = "redis-node{}" + for i in range(nnodes): + ret = json.loads(default_store.get(key_pattern.format(i))) + hosts.append(ret) + else: + hosts.append(hostinfo) - if get_local_rank() == 0: - default_store.set(key_pattern.format(nrank), json.dumps(hostinfo)) + global _bootstrap_redis_hosts + _bootstrap_redis_hosts.extend(hosts) - for i in range(nnodes): - ret = json.loads(default_store.get(key_pattern.format(i))) - hosts.append(ret) - else: - hosts.append(hostinfo) - if self.cluster_mode: - 
self.hosts.extend(hosts) - else: - self.hosts.append(hosts[nrank]) +def get_bootstrapped_redis_server(cluster_mode): + if cluster_mode: + return _bootstrap_redis_hosts + else: + nrank = get_rank() // get_local_size() + return [_bootstrap_redis_hosts[nrank]] class _RedisStore(Store): diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index 6ca7fb42f..c07eb4f20 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -25,17 +25,23 @@ def check_dataset(self, dataset, cache_dataset): for _, _ in enumerate(cache_dataset): pass - self.assertEqual(cache_dataset.cache_loader.num_keys(), len(dataset)) for i in range(len(dataset)): self.assertTrue((dataset[i][0] == cache_dataset[i][0]).all()) self.assertTrue((dataset[i][1] == cache_dataset[i][1]).all()) def test_redis(self): - dataset = MyDataset(102) - cache_dataset = CachedDataset( - dataset, backend="redis", hosts=None, cluster_mode=False + dataset1 = MyDataset(102) + dataset2 = MyDataset(102) + cache_dataset1 = CachedDataset(dataset1, backend="redis", dataset_name="d1",) + cache_dataset2 = CachedDataset(dataset2, backend="redis", dataset_name="d2",) + + self.check_dataset(dataset1, cache_dataset1) + self.assertEqual(cache_dataset1.cache_loader.num_keys(), len(dataset1)) + + self.check_dataset(dataset2, cache_dataset2) + self.assertEqual( + cache_dataset2.cache_loader.num_keys(), len(dataset1) + len(dataset2) ) - self.check_dataset(dataset, cache_dataset) if __name__ == "__main__": From 0027c97d089d4e0ead6fb98ecd8edc0abe246a41 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 16:28:25 +0800 Subject: [PATCH 43/63] . 
--- bagua/torch_api/contrib/cache_loader.py | 2 +- bagua/torch_api/contrib/cached_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 659146821..6daabe782 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -34,7 +34,7 @@ def __init__( Args: backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. - key_prefix(str): Prefix added to the cache key. Default ``""``. + key_prefix(str): Prefix added to the cache key. Better to be short. Default ``""``. writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. Useful for improving the backend throughput. diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 432fa186d..c7a4de9cf 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -25,14 +25,14 @@ def __init__( Args: dataset: PyTorch Dataset to be wrapped. backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. - dataset_name(str): Name of the dataset. Default ``""``. + dataset_name(str): Name of the dataset. Better to be short. Default ``""``. writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. Useful for improving the backend throughput. Example:: >>> from bagua.torch_api.contrib import CachedDataset - >>> cache_dataset = CachedDataset(dataset, backend="redis") + >>> cache_dataset = CachedDataset(dataset, backend="redis", dataset_name="ds") >>> dataloader = torch.utils.data.DataLoader(cached_dataset) .. 
note:: From 32c066ddf09d16355c9f291f9a90d4652e909b4e Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Fri, 13 Aug 2021 16:29:10 +0800 Subject: [PATCH 44/63] Update tests/contrib/test_cached_dataset.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/contrib/test_cached_dataset.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index c07eb4f20..c0b172472 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -32,8 +32,16 @@ def check_dataset(self, dataset, cache_dataset): def test_redis(self): dataset1 = MyDataset(102) dataset2 = MyDataset(102) - cache_dataset1 = CachedDataset(dataset1, backend="redis", dataset_name="d1",) - cache_dataset2 = CachedDataset(dataset2, backend="redis", dataset_name="d2",) + cache_dataset1 = CachedDataset( + dataset1, + backend="redis", + dataset_name="d1", + ) + cache_dataset2 = CachedDataset( + dataset2, + backend="redis", + dataset_name="d2", + ) self.check_dataset(dataset1, cache_dataset1) self.assertEqual(cache_dataset1.cache_loader.num_keys(), len(dataset1)) From c9022bc9580c844af563e350c84ce95566e212a7 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 09:52:47 +0000 Subject: [PATCH 45/63] devdev --- bagua/torch_api/contrib/cache_loader.py | 5 --- bagua/torch_api/contrib/utils/redis_store.py | 38 ++++++++++++++------ docs/conf.py | 1 - tests/contrib/test_store.py | 6 ++-- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 6daabe782..b35fa3ca7 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -1,6 +1,5 @@ import pickle from collections import defaultdict -import atexit __all__ = ["CacheLoader"] @@ -68,7 +67,6 @@ def __init__( 
raise ValueError('invalid backend, only support "redis" currently') self.fetcher = BatchFetcher(self.store, 1, writer_buffer_size) - self.register_shutdown_handler() def get(self, key, load_fn): """ @@ -90,9 +88,6 @@ def num_keys(self): return self.store.num_keys() - def register_shutdown_handler(self): - atexit.register(self.store.shutdown) - class BatchFetcher: def __init__(self, store, read_buffer_size, writer_buffer_size): diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 8349eb235..dec6f30c1 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -16,6 +16,7 @@ import torch.distributed.distributed_c10d as c10d import json import logging +import atexit try: @@ -32,7 +33,8 @@ __all__ = ["RedisStore"] _host_ip = None -_bootstrap_redis_hosts = [] + +_global_redis_servers = [] class RedisStore(ClusterStore): @@ -79,7 +81,7 @@ def __init__( if self.bootstrap: bootstrap_redis_server(self.capacity_per_node) - self.hosts.extend(get_bootstrapped_redis_server(self.cluster_mode)) + self.hosts.extend(get_bootstrapped_host_info(self.cluster_mode)) stores = [] for h in self.hosts: @@ -92,14 +94,23 @@ def __init__( def _is_bootstrapped(): - global _bootstrap_redis_hosts + global _global_redis_servers + + return _global_redis_servers is not None and len(_global_redis_servers) > 0 + + +def shutdown_redis_server(): + global _global_redis_servers - return _bootstrap_redis_hosts is not None and len(_bootstrap_redis_hosts) > 0 + hostinfo = get_bootstrapped_host_info(cluster_mode=False)[0] + store = _RedisStore(host=hostinfo["host"], port=hostinfo["port"], bootstrap=True) + + store.shutdown() def bootstrap_redis_server(capacity_per_node): if _is_bootstrapped(): - logging.debug("local redis server is already bootstrapped") + logging.debug("local redis server has already bootstrapped") return ip, port = get_host_ip(), find_free_port() @@ -123,21 +134,27 @@ def 
bootstrap_redis_server(capacity_per_node): else: hosts.append(hostinfo) - global _bootstrap_redis_hosts - _bootstrap_redis_hosts.extend(hosts) + global _global_redis_servers + _global_redis_servers.extend(hosts) + + atexit.register(shutdown_redis_server) + +def get_bootstrapped_host_info(cluster_mode): + global _global_redis_servers -def get_bootstrapped_redis_server(cluster_mode): if cluster_mode: - return _bootstrap_redis_hosts + return _global_redis_servers else: nrank = get_rank() // get_local_size() - return [_bootstrap_redis_hosts[nrank]] + return [_global_redis_servers[nrank]] class _RedisStore(Store): def __init__(self, host, port, bootstrap): self.client = create_redis_client(host=host, port=port) + self.host = host + self.port = port self.bootstrap = bootstrap assert self._connect_with_retry( @@ -181,6 +198,7 @@ def status(self) -> bool: def shutdown(self): if self.bootstrap: + logging.debug(f"shutting down local redis server at port {self.port}") self.client.shutdown(nosave=True) # pytype: disable=wrong-keyword-args diff --git a/docs/conf.py b/docs/conf.py index 83440e5a4..ac239f13a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -127,7 +127,6 @@ "bagua.torch_api.contrib.*LoadBalancingDistributedSampler.shuffle_chunks", "bagua.torch_api.contrib.*LoadBalancingDistributedBatchSampler.generate_batches", "bagua.torch_api.contrib.utils.store.ClusterStore.*", - "bagua.torch_api.contrib.*CacheLoader.register_shutdown_handler", ] _ignore_functions = [ "bagua.torch_api.env.get_autotune_server_addr", diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index 69f36e31f..19350e1f3 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -44,9 +44,6 @@ def check(self, store): self.assertTrue(store.status()) - # try to shut down resources - store.shutdown() - def test_redis_store(self): store = RedisStore(hosts=None, cluster_mode=False, capacity_per_node=10000000) self.check(store) @@ -74,9 +71,10 @@ def 
test_redis_cluster_store(self): store = RedisStore(hosts=hosts, cluster_mode=True) self.check(store) + store.shutdown() self.assertTrue(store.status()) - # Now shut down servers safely + # Now shut down servers manually for port in ports: client = redis.Redis(port=port) client.shutdown(nosave=True) From 1e21deb3262b705d6eff0d2741f8d07aea58c68d Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 20:02:20 +0800 Subject: [PATCH 46/63] update --- bagua/torch_api/contrib/cached_dataset.py | 8 ++++---- bagua/torch_api/contrib/utils/redis_store.py | 17 +++++++++-------- tests/contrib/test_cached_dataset.py | 14 ++++---------- tests/contrib/test_store.py | 11 ++++------- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index c7a4de9cf..ad2735255 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -45,11 +45,11 @@ def __init__( self.dataset = dataset self.cache_loader = CacheLoader( - backend, - dataset_name, - writer_buffer_size, - **kwargs, + backend, dataset_name, writer_buffer_size, **kwargs, ) + """ + The backend cache instance. 
+ """ def __getitem__(self, item): return self.cache_loader.get(item, lambda x: self.dataset[x]) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index dec6f30c1..a9c81f2f9 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -85,9 +85,7 @@ def __init__( stores = [] for h in self.hosts: - store = _RedisStore( - host=h["host"], port=h["port"], bootstrap=self.bootstrap - ) + store = _RedisStore(host=h["host"], port=h["port"]) stores.append(store) super(RedisStore, self).__init__(stores, hash_fn) @@ -103,7 +101,7 @@ def shutdown_redis_server(): global _global_redis_servers hostinfo = get_bootstrapped_host_info(cluster_mode=False)[0] - store = _RedisStore(host=hostinfo["host"], port=hostinfo["port"], bootstrap=True) + store = _RedisStore(host=hostinfo["host"], port=hostinfo["port"]) store.shutdown() @@ -151,11 +149,10 @@ def get_bootstrapped_host_info(cluster_mode): class _RedisStore(Store): - def __init__(self, host, port, bootstrap): + def __init__(self, host, port): self.client = create_redis_client(host=host, port=port) self.host = host self.port = port - self.bootstrap = bootstrap assert self._connect_with_retry( retry_times=3 @@ -197,8 +194,12 @@ def status(self) -> bool: return self.client.ping() def shutdown(self): - if self.bootstrap: - logging.debug(f"shutting down local redis server at port {self.port}") + if self.host != get_host_ip(): + logging.error(f"Could not shut down non-local redis servers.") + else: + logging.debug( + f"CLEANUP: shutting down local redis server at port {self.port}." 
+ ) self.client.shutdown(nosave=True) # pytype: disable=wrong-keyword-args diff --git a/tests/contrib/test_cached_dataset.py b/tests/contrib/test_cached_dataset.py index c0b172472..ebe619fc4 100644 --- a/tests/contrib/test_cached_dataset.py +++ b/tests/contrib/test_cached_dataset.py @@ -32,16 +32,10 @@ def check_dataset(self, dataset, cache_dataset): def test_redis(self): dataset1 = MyDataset(102) dataset2 = MyDataset(102) - cache_dataset1 = CachedDataset( - dataset1, - backend="redis", - dataset_name="d1", - ) - cache_dataset2 = CachedDataset( - dataset2, - backend="redis", - dataset_name="d2", - ) + cache_dataset1 = CachedDataset(dataset1, backend="redis", dataset_name="d1",) + cache_dataset2 = CachedDataset(dataset2, backend="redis", dataset_name="d2",) + + cache_dataset1.cache_loader.store.clear() self.check_dataset(dataset1, cache_dataset1) self.assertEqual(cache_dataset1.cache_loader.num_keys(), len(dataset1)) diff --git a/tests/contrib/test_store.py b/tests/contrib/test_store.py index 19350e1f3..bb3d091c0 100644 --- a/tests/contrib/test_store.py +++ b/tests/contrib/test_store.py @@ -16,6 +16,9 @@ class TestRedisStore(unittest.TestCase): def check(self, store): + store.clear() + self.assertEqual(store.num_keys(), 0) + self.generated_data = [np.random.rand(10) for _ in range(5)] store.set("1", pickle.dumps(self.generated_data[1])) @@ -39,9 +42,6 @@ def check(self, store): cnt = store.num_keys() self.assertEqual(cnt, 4) - store.clear() - self.assertEqual(store.num_keys(), 0) - self.assertTrue(store.status()) def test_redis_store(self): @@ -71,10 +71,7 @@ def test_redis_cluster_store(self): store = RedisStore(hosts=hosts, cluster_mode=True) self.check(store) - store.shutdown() - self.assertTrue(store.status()) - - # Now shut down servers manually + # Shut down servers manually for port in ports: client = redis.Redis(port=port) client.shutdown(nosave=True) From 792ab7db4e9d49029f0da9470724d35cb547bb8b Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 
13:06:14 +0000 Subject: [PATCH 47/63] . --- bagua/torch_api/contrib/cache_loader.py | 2 +- bagua/torch_api/contrib/utils/redis_store.py | 70 +++++++++----------- 2 files changed, 32 insertions(+), 40 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index b35fa3ca7..29c3ab9ff 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -64,7 +64,7 @@ def __init__( self.store = RedisStore(**kwargs) else: - raise ValueError('invalid backend, only support "redis" currently') + raise ValueError('Invalid backend, only support "redis" currently') self.fetcher = BatchFetcher(self.store, 1, writer_buffer_size) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index a9c81f2f9..2b2d4850c 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -67,24 +67,21 @@ def __init__( hash_fn=None, ): - self.hosts = [] if hosts is None: - logging.info("Ready to bootstrap redis server locally") - self.bootstrap = True + logging.info("Ready to bootstrap redis server locally.") + hosts = bootstrap_redis_server(capacity_per_node) else: + assert len(hosts) > 0, "RedisStore hosts should not be empty." 
logging.info("Ready to connect redis servers: {}".format(hosts)) - self.bootstrap = False - self.hosts.extend(hosts) - self.cluster_mode = cluster_mode - self.capacity_per_node = capacity_per_node - - if self.bootstrap: - bootstrap_redis_server(self.capacity_per_node) - self.hosts.extend(get_bootstrapped_host_info(self.cluster_mode)) + to_connect = [] + if cluster_mode: + to_connect.extend(hosts) + else: + to_connect.append(hosts[get_node_rank() % len(hosts)]) stores = [] - for h in self.hosts: + for h in to_connect: store = _RedisStore(host=h["host"], port=h["port"]) stores.append(store) @@ -100,52 +97,39 @@ def _is_bootstrapped(): def shutdown_redis_server(): global _global_redis_servers - hostinfo = get_bootstrapped_host_info(cluster_mode=False)[0] + hostinfo = _global_redis_servers[get_node_rank() % len(_global_redis_servers)] store = _RedisStore(host=hostinfo["host"], port=hostinfo["port"]) store.shutdown() def bootstrap_redis_server(capacity_per_node): + global _global_redis_servers + if _is_bootstrapped(): - logging.debug("local redis server has already bootstrapped") - return + logging.debug("Local redis server has already bootstrapped.") + return _global_redis_servers - ip, port = get_host_ip(), find_free_port() - hostinfo = {"host": ip, "port": port} + host, port = get_host_ip(), find_free_port() + hostinfo = {"host": host, "port": port} if get_local_rank() == 0: start_redis_server_cli(port, capacity_per_node) + atexit.register(shutdown_redis_server) - hosts = [] - nrank = get_rank() // get_local_size() if get_world_size() > 1: - nnodes = get_world_size() // get_local_size() default_store = c10d._get_default_store() key_pattern = "redis-node{}" if get_local_rank() == 0: - default_store.set(key_pattern.format(nrank), json.dumps(hostinfo)) + default_store.set(key_pattern.format(get_node_rank()), json.dumps(hostinfo)) - for i in range(nnodes): + for i in range(get_num_nodes()): ret = json.loads(default_store.get(key_pattern.format(i))) - hosts.append(ret) + 
_global_redis_servers.append(ret) else: - hosts.append(hostinfo) + _global_redis_servers.append(hostinfo) - global _global_redis_servers - _global_redis_servers.extend(hosts) - - atexit.register(shutdown_redis_server) - - -def get_bootstrapped_host_info(cluster_mode): - global _global_redis_servers - - if cluster_mode: - return _global_redis_servers - else: - nrank = get_rank() // get_local_size() - return [_global_redis_servers[nrank]] + return _global_redis_servers class _RedisStore(Store): @@ -221,10 +205,18 @@ def start_redis_server_cli(port, capacity, *args): ] cmd.extend(list(args)) - logging.debug(f"start redis server, command: {cmd}") + logging.debug(f"Start redis server, command: {cmd}") subprocess.run(cmd) +def get_node_rank(): + return get_rank() // get_local_size() + + +def get_num_nodes(): + return get_world_size() // get_local_size() + + def find_free_port(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) From 4c0b046bd2024d1d2a3d7e0f9dbf4c9a131fde9d Mon Sep 17 00:00:00 2001 From: wangraying <45031995+wangraying@users.noreply.github.com> Date: Fri, 13 Aug 2021 21:11:23 +0800 Subject: [PATCH 48/63] Update bagua/torch_api/contrib/cached_dataset.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- bagua/torch_api/contrib/cached_dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index ad2735255..012e10b70 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -45,7 +45,10 @@ def __init__( self.dataset = dataset self.cache_loader = CacheLoader( - backend, dataset_name, writer_buffer_size, **kwargs, + backend, + dataset_name, + writer_buffer_size, + **kwargs, ) """ The backend cache instance. 
From cd8afbefc7e9b8a3d0e42a251482ebf72abb5ea7 Mon Sep 17 00:00:00 2001 From: ritaw Date: Fri, 13 Aug 2021 13:13:50 +0000 Subject: [PATCH 49/63] fmt --- bagua/torch_api/contrib/utils/redis_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index 2b2d4850c..b0e86be40 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -179,7 +179,7 @@ def status(self) -> bool: def shutdown(self): if self.host != get_host_ip(): - logging.error(f"Could not shut down non-local redis servers.") + logging.error("Could not shut down non-local redis servers.") else: logging.debug( f"CLEANUP: shutting down local redis server at port {self.port}." From 0f49c2e398bc53aac320d872a4bd4449ba620ae6 Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 16 Aug 2021 13:55:13 +0800 Subject: [PATCH 50/63] . --- bagua/torch_api/contrib/cache_loader.py | 2 +- bagua/torch_api/contrib/cached_dataset.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 29c3ab9ff..744cb7a77 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -33,7 +33,7 @@ def __init__( Args: backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. - key_prefix(str): Prefix added to the cache key. Better to be short. Default ``""``. + key_prefix(str): Prefix added to the cache key. Default ``""``. writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. Useful for improving the backend throughput. 
diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 012e10b70..33352d047 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -14,18 +14,18 @@ def __init__( **kwargs, ): """ - CachedDataset wraps a PyTorch Dataset to cache its samples in memory, so that accessing these samples after the + `CachedDataset` wraps a PyTorch Dataset to cache its samples in memory, so that accessing these samples after the first time can be much faster. This is useful when samples need tedious preprocessing to produce, or reading the dataset itself is slow, which could slow down the whole training process. Internally, the samples are indexed by a key ``"{dataset_name}_{index}"`` and saved in a distributed Key-Value - store, where ``dataset_name`` is specified when initializing the CachedDataset, and ``index`` is the index + store, where ``dataset_name`` is specified when initializing the `CachedDataset`, and ``index`` is the index of a specific sample (the argument of `__getitem__(...)` method in a PyTorch Dataset). Args: dataset: PyTorch Dataset to be wrapped. backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. - dataset_name(str): Name of the dataset. Better to be short. Default ``""``. + dataset_name(str): Name of the dataset. Default ``""``. writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. Useful for improving the backend throughput. 
From ec684a80f8c6c5c55d1bbe63a46a19a15b654322 Mon Sep 17 00:00:00 2001 From: ritaw Date: Thu, 19 Aug 2021 10:15:13 +0800 Subject: [PATCH 51/63] add --- bagua/torch_api/contrib/cache_loader.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 744cb7a77..48d326f39 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -16,24 +16,24 @@ class CacheLoader: def __init__( self, backend: str = "redis", - key_prefix: str = "", + dataset_name: str = "", writer_buffer_size: int = 1, **kwargs, ): """ - A mapping from keys to values. Values are automatically loaded by the cache, and - are stored in the cache until evicted. + `CacheLoader` caches values calculated by an expensive function by theirs keys via :func:`CacheLoader.get` method, + so that the values can be retrieved faster next time. - Internally, values are indexed by ``"{key_prefix}_{key}"`` and saved in a distributed Key-Value - store, where ``key_prefix`` is specified on initializing, and ``key`` is the argument in :func:`get`. + Internally, values are indexed by ``"{dataset_name}_{key}"`` and saved in a distributed Key-Value + store, where ``dataset_name`` is specified on initializing, and ``key`` is the argument in :func:`CacheLoader.get`. - By default, CacheLoader uses :class:`RedisStore` as its backend distributed Key-Value store implementation. It + By default, `CacheLoader` uses :class:`RedisStore` as its backend distributed Key-Value store implementation. It could reuse a list of initialized redis servers or bootstrap local redis servers by itself. See :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` for further customization. Args: backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. - key_prefix(str): Prefix added to the cache key. Default ``""``. + dataset_name(str): Name of the dataset. 
Default ``""``. writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. Useful for improving the backend throughput. @@ -43,7 +43,7 @@ def __init__( >>> from bagua.torch_api.contrib import CacheLoader >>> >>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}] - >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, key_prefix="test") + >>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, dataset_name="test") >>> >>> loader.get(index, lambda x: items[x]) @@ -52,12 +52,12 @@ def __init__( >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000) .. note:: - Setting a specific `key_prefix` can be useful to avoid overwriting existing cache data. + Setting a specific `dataset_name` can be useful to avoid overwriting existing cache data. """ self.backend = backend - self.key_prefix = key_prefix + self.dataset_name = dataset_name if backend == "redis": from .utils.redis_store import RedisStore @@ -74,7 +74,7 @@ def get(self, key, load_fn): `load_fn` accepts `key` as input, and returns the data to be serialized and stored. 
""" - cache_key = "{}{}".format(self.key_prefix, key) + cache_key = "{}{}".format(self.dataset_name, key) ret = self.fetcher.read(cache_key) if ret is None: From 09d4f4521fb23cbb0c75e2d359e4ba959f9fdb4b Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 00:01:12 -0700 Subject: [PATCH 52/63] Update cache_loader.py --- bagua/torch_api/contrib/cache_loader.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 48d326f39..21cb6a0e8 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -28,8 +28,9 @@ def __init__( store, where ``dataset_name`` is specified on initializing, and ``key`` is the argument in :func:`CacheLoader.get`. By default, `CacheLoader` uses :class:`RedisStore` as its backend distributed Key-Value store implementation. It - could reuse a list of initialized redis servers or bootstrap local redis servers by itself. See - :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore` for further customization. + supports using a list of existing redis servers or spawning new redis servers. See also + :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore`. Parameters for `RedisStore` can be provided here in + ``**kwargs``. Args: backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. @@ -38,7 +39,7 @@ def __init__( Useful for improving the backend throughput. 
Example:: - To reuse a list of initialized redis servers for "redis" backend: + To use a list of existing redis servers for the "redis" backend: >>> from bagua.torch_api.contrib import CacheLoader >>> @@ -47,12 +48,13 @@ def __init__( >>> >>> loader.get(index, lambda x: items[x]) - To bootstrap local redis servers for "redis" backend: + To spawn new redis servers for the "redis" backend: >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000) .. note:: - Setting a specific `dataset_name` can be useful to avoid overwriting existing cache data. + ``CacheLoader``s with the same ``dataset_name`` will reuse and overwrite each other's cache. + Use different ``dataset_name`` if this is not desired. """ @@ -70,8 +72,9 @@ def __init__( def get(self, key, load_fn): """ - Returns the value associated with key in cache, first loading the value if necessary. - `load_fn` accepts `key` as input, and returns the data to be serialized and stored. + Returns the value associated with key in cache, use ``load_fn`` to create the entry if the key does not exist + in the cache. ``load_fn`` is a function taking ``key`` as its argument, and returning corresponding value to + be cached. 
""" cache_key = "{}{}".format(self.dataset_name, key) @@ -84,7 +87,7 @@ def get(self, key, load_fn): return ret def num_keys(self): - """Returns the total number of keys in cache""" + """Returns the number of keys in the cache.""" return self.store.num_keys() From a86eb959a3d59e7534842e3ffdaee6a79d62619f Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 00:02:49 -0700 Subject: [PATCH 53/63] Update cached_dataset.py --- bagua/torch_api/contrib/cached_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 33352d047..0fb48448e 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -37,8 +37,8 @@ def __init__( .. note:: `CachedDataset` is a special case of `CacheLoader`, and parameter `backend` and `writer_buffer_size` - in `CachedDataset` have the same meanings with those in `CacheLoader`. Further customization can be found in - :class:`bagua.torch_api.contrib.CacheLoader`. + in `CachedDataset` have the same meanings with those in `CacheLoader`. You can provide ``CacheLoader``'s + argument here in ``**kwargs``. See also :class:`bagua.torch_api.contrib.CacheLoader`. """ From 3291b422a61fc331dd9410b08dbfc5260b75f67f Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 00:03:23 -0700 Subject: [PATCH 54/63] Update cache_loader.py --- bagua/torch_api/contrib/cache_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 21cb6a0e8..dd37631e6 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -77,7 +77,7 @@ def get(self, key, load_fn): be cached. 
""" - cache_key = "{}{}".format(self.dataset_name, key) + cache_key = "{}_{}".format(self.dataset_name, key) ret = self.fetcher.read(cache_key) if ret is None: From 5a6714894ab8e63cf46619d740030f3214b996b1 Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 00:13:03 -0700 Subject: [PATCH 55/63] Update redis_store.py --- bagua/torch_api/contrib/utils/redis_store.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index b0e86be40..f5c2999a7 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -39,24 +39,17 @@ class RedisStore(ClusterStore): """ - A Redis-based Key-Value store implementation. - - The server holds the data, while the client can connect to the server over Redis protocol and perform - actions such as `set()` to insert a key-value pair, `get()` to retrieve a key-value pair, etc. + A Redis-based distributed Key-Value store implementation, with ``set(...)`` and ``get(...)``` API exposed. Args: hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of dict containing server host and - port information. Can be ``None``, which means to bootstrap redis servers locally. - cluster_mode (bool): Redis servers serve as a cluster or not. If True, data is automatically sharded across all - redis servers, otherwise, data is routed to a specific server. - capacity_per_node (int): Maximum memory limit in bytes to configure redis servers when bootstrap locally. Redis servers - will evict keys randomly when maximum memory limit is reached. - hash_fn: Hash function to compute the shard key. Default is `xxh64`. A `hash_fn` accepts a `str` as - input, and returns an `int` as output. + port information. New Redis instances will be spawned if ``hosts=None``. + cluster_mode (bool): If ``True``, data is sharded across all Redis instances. 
Otherwise, data is routed to a specific server. + capacity_per_node (int): Maximum memory limit in bytes when spawning new Redis instances. Old values will be evicted when the limit is reached. + hash_fn: Hash function to determine which shard a key belongs to. ``hash_fn`` accepts a ``str`` and returns an ``int`` as output. .. note:: - Only one redis server can be bootstrapped on each node, thus the maximum memory limit of it is determined on - its first initialization. + All Bagua jobs will share the same local Redis instance if ``hosts=None``. """ def __init__( From 4553c3b4d0df9be1d047973ac4e4643943df3758 Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 00:17:40 -0700 Subject: [PATCH 56/63] Update store.py --- bagua/torch_api/contrib/utils/store.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 6101f970f..1119d347c 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -7,9 +7,8 @@ class Store: """ - Base class for all Key-Value store implementations. A store keeps a mapping from keys to values. - key-value pairs are manually added to store using `set()` or `mset()` and can be retrieved by - `get()` or `mget()`. + Base class for Key-Value store implementations. Entries are added to store with ``set()`` or ``mset()``, and retrieved + with ``get()`` or ``mget()``. 
""" def set(self, key, value): From 83596175c3c80e59317833a23824e502962f65ec Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 00:27:41 -0700 Subject: [PATCH 57/63] Update store.py --- bagua/torch_api/contrib/utils/store.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 1119d347c..1720fd3b6 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -16,7 +16,7 @@ def set(self, key, value): pass def get(self, key) -> Optional[Union[str, bytes]]: - """Returns the value associated with key `key`, or None if the key doesn't exist.""" + """Returns the value associated with ``key``, or None if the key doesn't exist.""" pass # type: ignore def num_keys(self) -> int: @@ -35,7 +35,7 @@ def mset(self, mapping): def mget(self, keys) -> List[Optional[Union[str, bytes]]]: """ - Returns a list of values ordered identically to `keys`. + Retrieve each key's corresponding value and return them in a list with the same order as ``keys``. """ pass # type: ignore @@ -48,8 +48,7 @@ def status(self) -> bool: def shutdown(self): """ - Shutdown the current store. External store resources, for example, initialized redis servers, - will not be shut down by this method. + Shutdown the managed store instances. Unmanaged instances will not be killed. """ pass @@ -58,14 +57,10 @@ class ClusterStore(Store): """ Base class for distributed Key-Value stores. - This class implements client side sharding. It uses **xxHash** algorithm to compute the shard key by default, and can - accept customized hashing algorithms by passing `hash_fn` on initialization. - - key-value pairs are manually added to the cluster using `set()` or `mset()` and can be retrieved by - `get()` or `mget()`. + In ``ClusterStore``, entries will be sharded equally among multiple store instances based on their keys. 
Args: - stores(List[Store]): A list of stores in the cluster. + stores(List[Store]): A list of stores to shard entries on. hash_fn: Hash function to compute the shard key. Default is `xxh64`. A `hash_fn` accepts a `str` as input, and returns an `int` as output. From 58b8e647f01752392d731bdbff2f1e8861010cf7 Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 01:01:11 -0700 Subject: [PATCH 58/63] Update redis_store.py --- bagua/torch_api/contrib/utils/redis_store.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index f5c2999a7..f10dfe21b 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -42,9 +42,11 @@ class RedisStore(ClusterStore): A Redis-based distributed Key-Value store implementation, with ``set(...)`` and ``get(...)``` API exposed. Args: - hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of dict containing server host and - port information. New Redis instances will be spawned if ``hosts=None``. - cluster_mode (bool): If ``True``, data is sharded across all Redis instances. Otherwise, data is routed to a specific server. + hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of dict containing Redis host and + port information like ``[{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}]``. + A new Redis instance will be spawned on each node if ``hosts=None``. + cluster_mode (bool): If ``True``, data is sharded across all Redis instances. Otherwise, if there are ``m`` + Redis instances, the workers on the n-th node will use the ``n % m``-th Redis instance. capacity_per_node (int): Maximum memory limit in bytes when spawning new Redis instances. Old values will be evicted when the limit is reached. hash_fn: Hash function to determine which shard a key belongs to. 
``hash_fn`` accepts a ``str`` and returns an ``int`` as output. From 34d7672caa195ad976cfd386a5230c00101810ae Mon Sep 17 00:00:00 2001 From: Xiangru Lian Date: Mon, 23 Aug 2021 01:24:41 -0700 Subject: [PATCH 59/63] Update redis_store.py --- bagua/torch_api/contrib/utils/redis_store.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index f10dfe21b..826fb14d1 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -51,7 +51,8 @@ class RedisStore(ClusterStore): hash_fn: Hash function to determine which shard a key belongs to. ``hash_fn`` accepts a ``str`` and returns an ``int`` as output. .. note:: - All Bagua jobs will share the same local Redis instance if ``hosts=None``. + All Bagua jobs will share the same local Redis instance if ``hosts=None``. The ``capacity_per_node`` only affects + newly spawned Redis instances, and has no effect on existing Redis instances. 
""" def __init__( From 24cd3c7d35108cdad3c60e235052f403baf8d96b Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 23 Aug 2021 17:22:30 +0800 Subject: [PATCH 60/63] update doc --- bagua/torch_api/contrib/cache_loader.py | 10 +++---- bagua/torch_api/contrib/cached_dataset.py | 16 +++++----- bagua/torch_api/contrib/utils/redis_store.py | 13 ++++---- bagua/torch_api/contrib/utils/store.py | 31 ++++++++------------ 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index dd37631e6..8993ebcc1 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -21,15 +21,15 @@ def __init__( **kwargs, ): """ - `CacheLoader` caches values calculated by an expensive function by theirs keys via :func:`CacheLoader.get` method, + Cache loader caches values calculated by an expensive function by theirs keys via :func:`get` method, so that the values can be retrieved faster next time. Internally, values are indexed by ``"{dataset_name}_{key}"`` and saved in a distributed Key-Value - store, where ``dataset_name`` is specified on initializing, and ``key`` is the argument in :func:`CacheLoader.get`. + store, where ``dataset_name`` is specified on initializing, and ``key`` is the argument in :func:`get`. - By default, `CacheLoader` uses :class:`RedisStore` as its backend distributed Key-Value store implementation. It + By default, cache loader uses :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` as its backend distributed Key-Value store implementation. It supports using a list of existing redis servers or spawning new redis servers. See also - :class:`bagua.torch_api.contrib.utils.redis_store.RedisStore`. Parameters for `RedisStore` can be provided here in + :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore`. Parameters for :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` can be provided here in ``**kwargs``. 
Args: @@ -53,7 +53,7 @@ def __init__( >>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000) .. note:: - ``CacheLoader``s with the same ``dataset_name`` will reuse and overwrite each other's cache. + Cache loaders with the same ``dataset_name`` will reuse and overwrite each other's cache. Use different ``dataset_name`` if this is not desired. """ diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 0fb48448e..7a8505ef0 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -14,16 +14,16 @@ def __init__( **kwargs, ): """ - `CachedDataset` wraps a PyTorch Dataset to cache its samples in memory, so that accessing these samples after the + Cached dataset wraps a PyTorch dataset to cache its samples in memory, so that accessing these samples after the first time can be much faster. This is useful when samples need tedious preprocessing to produce, or reading the dataset itself is slow, which could slow down the whole training process. Internally, the samples are indexed by a key ``"{dataset_name}_{index}"`` and saved in a distributed Key-Value - store, where ``dataset_name`` is specified when initializing the `CachedDataset`, and ``index`` is the index - of a specific sample (the argument of `__getitem__(...)` method in a PyTorch Dataset). + store, where ``dataset_name`` is specified when initializing the cached dataset, and ``index`` is the index + of a specific sample (the argument of :func:`__getitem__(...)` method in a PyTorch dataset). Args: - dataset: PyTorch Dataset to be wrapped. + dataset: PyTorch dataset to be wrapped. backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. dataset_name(str): Name of the dataset. Default ``""``. writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. 
@@ -36,9 +36,9 @@ def __init__( >>> dataloader = torch.utils.data.DataLoader(cached_dataset) .. note:: - `CachedDataset` is a special case of `CacheLoader`, and parameter `backend` and `writer_buffer_size` - in `CachedDataset` have the same meanings with those in `CacheLoader`. You can provide ``CacheLoader``'s - argument here in ``**kwargs``. See also :class:`bagua.torch_api.contrib.CacheLoader`. + Cached dataset is a special case of cache loader, and parameter ``backend`` and ``writer_buffer_size`` in + in initializing a cached dataset have the same meanings with those in initializing a cache loader. You can + provide the arguments for cache loader here in ``**kwargs``. See also :class:`~bagua.torch_api.contrib.cache_loader.CacheLoader`. """ @@ -51,7 +51,7 @@ def __init__( **kwargs, ) """ - The backend cache instance. + The backend cache loader instance. """ def __getitem__(self, item): diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index f5c2999a7..ab5f2e2a4 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -46,18 +46,17 @@ class RedisStore(ClusterStore): port information. New Redis instances will be spawned if ``hosts=None``. cluster_mode (bool): If ``True``, data is sharded across all Redis instances. Otherwise, data is routed to a specific server. capacity_per_node (int): Maximum memory limit in bytes when spawning new Redis instances. Old values will be evicted when the limit is reached. - hash_fn: Hash function to determine which shard a key belongs to. ``hash_fn`` accepts a ``str`` and returns an ``int`` as output. + Default is 100GB. .. note:: - All Bagua jobs will share the same local Redis instance if ``hosts=None``. + All Bagua jobs within the same node will share the same local Redis instance if ``hosts=None``. 
""" def __init__( self, hosts: Optional[List[Dict[str, str]]] = None, cluster_mode: bool = True, - capacity_per_node: int = 100_000_000_000, - hash_fn=None, + capacity_per_node: int = 107_374_182_400, ): if hosts is None: @@ -78,7 +77,7 @@ def __init__( store = _RedisStore(host=h["host"], port=h["port"]) stores.append(store) - super(RedisStore, self).__init__(stores, hash_fn) + super(RedisStore, self).__init__(stores) def _is_bootstrapped(): @@ -161,8 +160,8 @@ def num_keys(self) -> int: def clear(self): self.client.flushdb() - def mset(self, mapping: Dict[str, Union[str, bytes]]): - self.client.mset(mapping) + def mset(self, dictionary: Dict[str, Union[str, bytes]]): + self.client.mset(dictionary) def mget(self, keys: List[str]) -> List[Optional[Union[str, bytes]]]: return self.client.mget(keys) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 1720fd3b6..568889137 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -27,9 +27,9 @@ def clear(self): """Delete all keys in the current store.""" pass - def mset(self, mapping): + def mset(self, dictionary): """ - Set key/values based on a mapping. Mapping is a dictionary of key/value pairs. + Set multiple entries at once with a dictionary. Each key-value pair in the ``dictionary`` will be set. """ pass @@ -42,7 +42,7 @@ def mget(self, keys) -> List[Optional[Union[str, bytes]]]: def status(self) -> bool: """ - Returns the status of the current store. + Returns ``True`` if the current store is alive. """ pass # type: ignore @@ -61,31 +61,26 @@ class ClusterStore(Store): Args: stores(List[Store]): A list of stores to shard entries on. - hash_fn: Hash function to compute the shard key. Default is `xxh64`. A `hash_fn` accepts a `str` as - input, and returns an `int` as output. 
""" - def __init__(self, stores: List[Store], hash_fn=None): + def __init__(self, stores: List[Store]): self.stores = stores self.num_stores = len(stores) - if hash_fn is None: - import xxhash + import xxhash - def xxh64(x): - return xxhash.xxh64(x).intdigest() + def xxh64(x): + return xxhash.xxh64(x).intdigest() - hash_fn = xxh64 + self.hash_fn = xxh64 - self.hash_fn = hash_fn - - def _hash_key(self, key) -> int: + def _hash_key(self, key: str) -> int: hash_code = self.hash_fn(key) return hash_code % self.num_stores - def route(self, key) -> Store: + def route(self, key: str) -> Store: return ( self.stores[self._hash_key(key)] if self.num_stores > 1 else self.stores[0] ) @@ -109,12 +104,12 @@ def clear(self): for store in self.stores: store.clear() - def mset(self, mapping: Dict[str, Union[str, bytes]]): + def mset(self, dictionary: Dict[str, Union[str, bytes]]): if self.num_stores == 1: - return self.stores[0].mset(mapping) + return self.stores[0].mset(dictionary) route_table = {} - for k, v in mapping.items(): + for k, v in dictionary.items(): sid = self._hash_key(k) m = route_table.get(sid, defaultdict(dict)) m[k] = v From c94f67c90d3ac8e4f6b9b852d4a4e6704f24eacb Mon Sep 17 00:00:00 2001 From: ritaw Date: Mon, 23 Aug 2021 18:22:17 +0800 Subject: [PATCH 61/63] . 
--- bagua/torch_api/contrib/cache_loader.py | 4 +++- bagua/torch_api/contrib/cached_dataset.py | 4 ++-- bagua/torch_api/contrib/utils/store.py | 8 ++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 8993ebcc1..742774cdb 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -1,5 +1,7 @@ import pickle from collections import defaultdict +from typing import Callable + __all__ = ["CacheLoader"] @@ -70,7 +72,7 @@ def __init__( self.fetcher = BatchFetcher(self.store, 1, writer_buffer_size) - def get(self, key, load_fn): + def get(self, key: str, load_fn: Callable[[str], None]): """ Returns the value associated with key in cache, use ``load_fn`` to create the entry if the key does not exist in the cache. ``load_fn`` is a function taking ``key`` as its argument, and returning corresponding value to diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 7a8505ef0..10cab59de 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -36,8 +36,8 @@ def __init__( >>> dataloader = torch.utils.data.DataLoader(cached_dataset) .. note:: - Cached dataset is a special case of cache loader, and parameter ``backend`` and ``writer_buffer_size`` in - in initializing a cached dataset have the same meanings with those in initializing a cache loader. You can + Cached dataset is a special case of cache loader. Parameter ``backend`` and ``writer_buffer_size`` in + initializing a cached dataset have the same meanings as those in initializing a cache loader. You can provide the arguments for cache loader here in ``**kwargs``. See also :class:`~bagua.torch_api.contrib.cache_loader.CacheLoader`. 
""" diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 568889137..e8d2de75c 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -11,11 +11,11 @@ class Store: with ``get()`` or ``mget()``. """ - def set(self, key, value): + def set(self, key: str, value: Union[str, bytes]): """Set a key-value pair.""" pass - def get(self, key) -> Optional[Union[str, bytes]]: + def get(self, key: str) -> Optional[Union[str, bytes]]: """Returns the value associated with ``key``, or None if the key doesn't exist.""" pass # type: ignore @@ -27,13 +27,13 @@ def clear(self): """Delete all keys in the current store.""" pass - def mset(self, dictionary): + def mset(self, dictionary: Dict[str, Union[str, bytes]]): """ Set multiple entries at once with a dictionary. Each key-value pair in the ``dictionary`` will be set. """ pass - def mget(self, keys) -> List[Optional[Union[str, bytes]]]: + def mget(self, keys: List[str]) -> List[Optional[Union[str, bytes]]]: """ Retrieve each key's corresponding value and return them in a list with the same order as ``keys``. 
""" From 1a97a30b3930cc5de96619984ab08901b1001d41 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 24 Aug 2021 10:29:24 +0800 Subject: [PATCH 62/63] format --- bagua/torch_api/contrib/cache_loader.py | 17 ++++++++--------- bagua/torch_api/contrib/cached_dataset.py | 11 ++++++----- bagua/torch_api/contrib/utils/redis_store.py | 12 ++++++------ bagua/torch_api/contrib/utils/store.py | 8 ++++---- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/bagua/torch_api/contrib/cache_loader.py b/bagua/torch_api/contrib/cache_loader.py index 742774cdb..bb35913fa 100644 --- a/bagua/torch_api/contrib/cache_loader.py +++ b/bagua/torch_api/contrib/cache_loader.py @@ -23,21 +23,20 @@ def __init__( **kwargs, ): """ - Cache loader caches values calculated by an expensive function by theirs keys via :func:`get` method, + Cache loader caches values calculated by an expensive function by their keys via :meth:`get`, so that the values can be retrieved faster next time. - Internally, values are indexed by ``"{dataset_name}_{key}"`` and saved in a distributed Key-Value - store, where ``dataset_name`` is specified on initializing, and ``key`` is the argument in :func:`get`. + Internally, values are indexed by ``"{dataset_name}_{key}"`` and saved in a distributed key-value + store, where ``dataset_name`` is specified on initializing, and ``key`` is the argument in :meth:`get`. - By default, cache loader uses :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` as its backend distributed Key-Value store implementation. It - supports using a list of existing redis servers or spawning new redis servers. See also - :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore`. Parameters for :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` can be provided here in + By default, cache loader uses :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` as its backend distributed key-value store implementation.
It + supports using a list of existing redis servers or spawning new redis servers. Parameters for :class:`~bagua.torch_api.contrib.utils.redis_store.RedisStore` can be provided here in ``**kwargs``. Args: - backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. + backend(str): Backend distributed key-value store implementation. Can be ``"redis"``. dataset_name(str): Name of the dataset. Default ``""``. - writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. + writer_buffer_size(int): Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput. Example:: @@ -74,7 +73,7 @@ def __init__( def get(self, key: str, load_fn: Callable[[str], None]): """ - Returns the value associated with key in cache, use ``load_fn`` to create the entry if the key does not exist + Returns the value associated with ``key`` in cache, use ``load_fn`` to create the entry if the key does not exist in the cache. ``load_fn`` is a function taking ``key`` as its argument, and returning corresponding value to be cached. """ diff --git a/bagua/torch_api/contrib/cached_dataset.py b/bagua/torch_api/contrib/cached_dataset.py index 10cab59de..2dc20f6c2 100644 --- a/bagua/torch_api/contrib/cached_dataset.py +++ b/bagua/torch_api/contrib/cached_dataset.py @@ -14,19 +14,20 @@ def __init__( **kwargs, ): """ - Cached dataset wraps a PyTorch dataset to cache its samples in memory, so that accessing these samples after the + Cached dataset wraps a `PyTorch dataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`_ + to cache its samples in memory, so that accessing these samples after the first time can be much faster. This is useful when samples need tedious preprocessing to produce, or reading the dataset itself is slow, which could slow down the whole training process.
- Internally, the samples are indexed by a key ``"{dataset_name}_{index}"`` and saved in a distributed Key-Value + Internally, the samples are indexed by a string key ``"{dataset_name}_{index}"`` and saved in a distributed key-value store, where ``dataset_name`` is specified when initializing the cached dataset, and ``index`` is the index - of a specific sample (the argument of :func:`__getitem__(...)` method in a PyTorch dataset). + of a specific sample (the argument of :meth:`__getitem__` method in a PyTorch dataset). Args: dataset: PyTorch dataset to be wrapped. - backend(str): Backend distributed Key-Value store implementation. Can be ``"redis"``. + backend(str): Backend distributed key-value store implementation. Can be ``redis``. dataset_name(str): Name of the dataset. Default ``""``. - writer_buffer_size(int): Number of samples to collect before writing to the backend Key-Value store. + writer_buffer_size(int): Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput. Example:: diff --git a/bagua/torch_api/contrib/utils/redis_store.py b/bagua/torch_api/contrib/utils/redis_store.py index c3437d576..5627a2588 100644 --- a/bagua/torch_api/contrib/utils/redis_store.py +++ b/bagua/torch_api/contrib/utils/redis_store.py @@ -39,20 +39,20 @@ class RedisStore(ClusterStore): """ - A Redis-based distributed Key-Value store implementation, with ``set(...)`` and ``get(...)``` API exposed. + A Redis-based distributed key-value store implementation, with :meth:`set` and :meth:`get` API exposed. Args: hosts (List[Dict[str, str]]): A list of redis servers, defined by a list of dict containing Redis host and port information like ``[{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}]``. - A new Redis instance will be spawned on each node if ``hosts=None``. + A new Redis instance will be spawned on each node if ``hosts`` is ``None``. 
cluster_mode (bool): If ``True``, data is sharded across all Redis instances. Otherwise, if there are ``m`` - Redis instances, the workers on the n-th node will use the ``n % m``-th Redis instance. + Redis instances, the workers on the ``n``-th node will use the ``n % m``-th Redis instance. capacity_per_node (int): Maximum memory limit in bytes when spawning new Redis instances. Old values will be evicted when the limit is reached. - Default is 100GB. + Default is ``100GB``. .. note:: - All Bagua jobs within the same node will share the same local Redis instance if ``hosts=None``. The ``capacity_per_node`` only affects - newly spawned Redis instances, and has no effect on existing Redis instances. + All Bagua jobs within the same node will share the same local Redis instance if ``hosts`` is ``None``. The ``capacity_per_node`` only affects + newly spawned Redis instances, and has no effect on existing ones. """ def __init__( diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index e8d2de75c..8b14dc2b6 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -7,8 +7,8 @@ class Store: """ - Base class for Key-Value store implementations. Entries are added to store with ``set()`` or ``mset()``, and retrieved - with ``get()`` or ``mget()``. + Base class for key-value store implementations. Entries are added to store with :meth:`set` or :meth:`mset`, and retrieved + with :meth:`get` or :meth:`mget`. """ def set(self, key: str, value: Union[str, bytes]): @@ -55,9 +55,9 @@ def shutdown(self): class ClusterStore(Store): """ - Base class for distributed Key-Value stores. + Base class for distributed key-value stores. - In ``ClusterStore``, entries will be sharded equally among multiple store instances based on their keys. + In cluster store, entries will be sharded equally among multiple store instances based on their keys. Args: stores(List[Store]): A list of stores to shard entries on. 
From db0ed142669b567f29af2373ea50d6885002a8f1 Mon Sep 17 00:00:00 2001 From: ritaw Date: Tue, 24 Aug 2021 10:38:31 +0800 Subject: [PATCH 63/63] . --- bagua/torch_api/contrib/utils/store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bagua/torch_api/contrib/utils/store.py b/bagua/torch_api/contrib/utils/store.py index 8b14dc2b6..be243c521 100644 --- a/bagua/torch_api/contrib/utils/store.py +++ b/bagua/torch_api/contrib/utils/store.py @@ -77,7 +77,7 @@ def xxh64(x): self.hash_fn = xxh64 def _hash_key(self, key: str) -> int: - hash_code = self.hash_fn(key) + hash_code = self.hash_fn(key.encode()) return hash_code % self.num_stores def route(self, key: str) -> Store: