Skip to content

Commit

Permalink
V0.1 test store update (#58)
Browse files Browse the repository at this point in the history
* updateed test_store.py

* fixed unpicklable store bug

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
  • Loading branch information
ysqyang and ysqyang committed Sep 23, 2020
1 parent 33641ce commit ef485fa
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion maro/rl/storage/column_based_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def sample_by_keys(self, keys: Sequence, sizes: Sequence, replace: bool = True):
return indexes, self.get(indexes)

def dumps(self):
return clone(self._store)
return clone(dict(self._store))

def get_by_key(self, key):
return self._store[key]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_update(self):
store = ColumnBasedStore(capacity=5, overwrite_type=OverwriteType.ROLLING)
store.put({"a": [1, 2, 3, 4, 5], "b": [6, 7, 8, 9, 10], "c": [11, 12, 13, 14, 15]})
store.update([0, 3], {"a": [-1, -4], "c": [-11, -14]})
actual = store.take()
actual = store.dumps()
expected = {"a": [-1, 2, 3, -4, 5], "b": [6, 7, 8, 9, 10], "c": [-11, 12, 13, -14, 15]}
self.assertEqual(actual, expected, msg=f"expected store content = {expected}, got {actual}")

Expand All @@ -54,7 +54,7 @@ def test_put_with_rolling_overwrite(self):
indexes = store.put({"a": [10, 11, 12, 13], "b": [14, 15, 16, 17], "c": [18, 19, 20, 21]})
expected = [-2, -1, 0, 1]
self.assertEqual(indexes, expected, msg=f"expected indexes = {expected}, got {indexes}")
actual = store.take()
actual = store.dumps()
expected = {"a": [12, 13, 3, 10, 11], "b": [16, 17, 6, 14, 15], "c": [20, 21, 9, 18, 19]}
self.assertEqual(actual, expected, msg=f"expected store content = {expected}, got {actual}")

Expand Down

0 comments on commit ef485fa

Please sign in to comment.