-
Notifications
You must be signed in to change notification settings - Fork 129
/
webvid.py
405 lines (354 loc) · 20.9 KB
/
webvid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""A streaming WebVid dataset."""
import os
from time import sleep
from typing import Any, Optional
from streaming.base import StreamingDataset
from streaming.base.dataset import TICK, _Iterator
from streaming.base.storage import download_file
class StreamingInsideWebVid(StreamingDataset):
    """WebVid streaming dataset whose videos are stored "inside" the shards.

    Keeping the videos inside the shards is how this is typically done.

    Args:
        remote (str, optional): Remote path or directory the dataset is downloaded from. When
            ``None``, the data has to exist locally already. StreamingDataset is configured with
            either ``streams`` or ``remote``/``local``. Defaults to ``None``.
        local (str, optional): Local working directory that shards are downloaded into and
            cached in while they are in use. A temp directory is used if not set.
            StreamingDataset is configured with either ``streams`` or ``remote``/``local``.
            Defaults to ``None``.
        split (str, optional): Dataset split to use, if any. When given, streaming goes from/to
            the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
        download_retry (int): How many times a download is re-attempted before giving up.
            Defaults to ``2``.
        download_timeout (float): How many seconds to wait for a shard download before an
            exception is raised. Defaults to ``60``.
        validate_hash (str, optional): Hash or checksum algorithm used to validate shards, if
            any. Defaults to ``None``.
        keep_zip (bool): Whether the compressed form is kept or deleted when downloaded shards
            are decompressed. If ``False``, keep iff remote is local or no remote. Defaults to
            ``False``.
        epoch_size (int, optional): Number of samples drawn per epoch, balanced across all
            streams. When ``None``, the total number of underlying samples is used. Set this
            when weighting streams relatively to target a larger or smaller epoch size.
            Defaults to ``None``.
        predownload (int, optional): Target number of samples each worker downloads ahead of
            the current sample. Workers try to stay this many samples ahead during, but not
            before, training. A value greater than the per-device batch size is recommended so
            that at least one per-device batch of samples is cached locally. When ``None``, it
            is derived from the per-device batch size and the number of canonical nodes as
            ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. Defaults to ``None``.
        cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard
            cache. Before a shard is downloaded, the least recently used resident shard(s) may
            be evicted (deleted from the local cache) to stay under the limit. ``None``
            disables shard eviction. Defaults to ``None``.
        partition_algo (str): Partitioning algorithm to use. Defaults to ``orig``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
            resumption. The sample space is split evenly by the number of canonical nodes. A
            higher value gives the StreamingDataset replicas more independent, non-overlapping
            paths through the shards per model replica (increasing data source diversity).
            Defaults to ``None``, which is interpreted as 64 times the number of nodes of the
            initial run.

            .. note::

                For sequential sample ordering, set ``shuffle`` to ``False`` and
                ``num_canonical_nodes`` to the number of physical nodes of the initial run.
        batch_size (int, optional): Per-device batch size, identical to the one given to the
            DataLoader. It affects how the dataset is partitioned over the workers and is
            needed for deterministic resumption and optimal performance. Defaults to ``None``.
        shuffle (bool): Whether samples are iterated in randomized order. Defaults to
            ``False``.
        shuffle_algo (str): Shuffling algorithm to use. Defaults to ``py1s``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
    """

    def get_item(self, idx: int) -> Any:
        """Fetch the sample stored at the given index.

        Args:
            idx (int): Index of the sample to fetch.

        Returns:
            Any: The sample, as returned by the parent class.
        """
        sample = super().get_item(idx)

        # Processing goes here.

        return sample
class StreamingOutsideGIWebVid(StreamingDataset):
    """WebVid streaming dataset whose videos are stored "outside" the shards.

    Each video is its own file. The extra per-video download is performed in get_item ("GI"),
    i.e. at the moment the dataloader requests the sample.

    Args:
        remote (str, optional): Remote path or directory the dataset is downloaded from. When
            ``None``, the data has to exist locally already. StreamingDataset is configured with
            either ``streams`` or ``remote``/``local``. Defaults to ``None``.
        local (str, optional): Local working directory that shards are downloaded into and
            cached in while they are in use. A temp directory is used if not set.
            StreamingDataset is configured with either ``streams`` or ``remote``/``local``.
            Defaults to ``None``.
        split (str, optional): Dataset split to use, if any. When given, streaming goes from/to
            the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
        download_retry (int): How many times a download is re-attempted before giving up.
            Defaults to ``2``.
        download_timeout (float): How many seconds to wait for a shard download before an
            exception is raised. The same value is passed to the extra per-video downloads.
            Defaults to ``60``.
        validate_hash (str, optional): Hash or checksum algorithm used to validate shards, if
            any. Defaults to ``None``.
        keep_zip (bool): Whether the compressed form is kept or deleted when downloaded shards
            are decompressed. If ``False``, keep iff remote is local or no remote. Defaults to
            ``False``.
        epoch_size (int, optional): Number of samples drawn per epoch, balanced across all
            streams. When ``None``, the total number of underlying samples is used. Set this
            when weighting streams relatively to target a larger or smaller epoch size.
            Defaults to ``None``.
        predownload (int, optional): Target number of samples each worker downloads ahead of
            the current sample. Workers try to stay this many samples ahead during, but not
            before, training. A value greater than the per-device batch size is recommended so
            that at least one per-device batch of samples is cached locally. When ``None``, it
            is derived from the per-device batch size and the number of canonical nodes as
            ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. Defaults to ``None``.
        cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard
            cache. Before a shard is downloaded, the least recently used resident shard(s) may
            be evicted (deleted from the local cache) to stay under the limit. ``None``
            disables shard eviction. Defaults to ``None``.
        partition_algo (str): Partitioning algorithm to use. Defaults to ``orig``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
            resumption. The sample space is split evenly by the number of canonical nodes. A
            higher value gives the StreamingDataset replicas more independent, non-overlapping
            paths through the shards per model replica (increasing data source diversity).
            Defaults to ``None``, which is interpreted as 64 times the number of nodes of the
            initial run.

            .. note::

                For sequential sample ordering, set ``shuffle`` to ``False`` and
                ``num_canonical_nodes`` to the number of physical nodes of the initial run.
        batch_size (int, optional): Per-device batch size, identical to the one given to the
            DataLoader. It affects how the dataset is partitioned over the workers and is
            needed for deterministic resumption and optimal performance. Defaults to ``None``.
        shuffle (bool): Whether samples are iterated in randomized order. Defaults to
            ``False``.
        shuffle_algo (str): Shuffling algorithm to use. Defaults to ``py1s``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
        extra_local (str, optional): Base destination of extra local sample downloads.
            Defaults to ``None``.
        extra_remote (str, optional): Base source of extra remote sample downloads. Defaults
            to ``None``.
    """

    def __init__(
        self,
        *,
        remote: Optional[str] = None,
        local: Optional[str] = None,
        split: Optional[str] = None,
        download_retry: int = 2,
        download_timeout: float = 60,
        validate_hash: Optional[str] = None,
        keep_zip: bool = False,
        epoch_size: Optional[int] = None,
        predownload: Optional[int] = None,
        cache_limit: Optional[int] = None,
        partition_algo: str = 'orig',
        num_canonical_nodes: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        shuffle_algo: str = 'py1s',
        shuffle_seed: int = 9176,
        shuffle_block_size: int = 1 << 18,
        extra_local: Optional[str] = None,
        extra_remote: Optional[str] = None,
    ) -> None:
        # Everything except the two ``extra_*`` arguments is forwarded to the parent unchanged.
        super().__init__(
            remote=remote,
            local=local,
            split=split,
            download_retry=download_retry,
            download_timeout=download_timeout,
            validate_hash=validate_hash,
            keep_zip=keep_zip,
            epoch_size=epoch_size,
            predownload=predownload,
            cache_limit=cache_limit,
            partition_algo=partition_algo,
            num_canonical_nodes=num_canonical_nodes,
            batch_size=batch_size,
            shuffle=shuffle,
            shuffle_algo=shuffle_algo,
            shuffle_seed=shuffle_seed,
            shuffle_block_size=shuffle_block_size,
        )

        # The videos live outside of their shards, so remember where to fetch them from, where
        # to put them, and how long each of those fetches may take.
        self.download_timeout = download_timeout
        self.extra_local = extra_local
        self.extra_remote = extra_remote

    def get_item(self, idx: int) -> Any:
        """Fetch the sample stored at the given index, attaching its video bytes.

        When both ``extra_local`` and ``extra_remote`` are set, the file named by the sample's
        ``'content_path'`` field is downloaded unless it already exists locally, then read and
        stored in the sample under ``'content'``. Otherwise the sample is returned as-is.

        Args:
            idx (int): Index of the sample to fetch.

        Returns:
            Any: The sample.
        """
        sample = super().get_item(idx)

        if self.extra_local and self.extra_remote:
            content_path = sample['content_path']
            dest = os.path.join(self.extra_local, content_path)
            src = os.path.join(self.extra_remote, content_path)

            # Only fetch the video if an earlier call has not already placed it on disk.
            if not os.path.exists(dest):
                download_file(src, dest, self.download_timeout)

            with open(dest, 'rb') as stream:
                sample['content'] = stream.read()

        # Processing goes here.

        return sample
class StreamingOutsideDTWebVid(StreamingDataset):
    """WebVid streaming dataset whose videos are stored "outside" the shards.

    Each video is its own file. The extra per-video download is performed in _download_thread
    ("DT"), i.e. when the download thread prefetches the sample.

    Args:
        remote (str, optional): Remote path or directory the dataset is downloaded from. When
            ``None``, the data has to exist locally already. StreamingDataset is configured with
            either ``streams`` or ``remote``/``local``. Defaults to ``None``.
        local (str, optional): Local working directory that shards are downloaded into and
            cached in while they are in use. A temp directory is used if not set.
            StreamingDataset is configured with either ``streams`` or ``remote``/``local``.
            Defaults to ``None``.
        split (str, optional): Dataset split to use, if any. When given, streaming goes from/to
            the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
        download_retry (int): How many times a download is re-attempted before giving up.
            Defaults to ``2``.
        download_timeout (float): How many seconds to wait for a shard download before an
            exception is raised. The same value is passed to the extra per-video downloads.
            Defaults to ``60``.
        validate_hash (str, optional): Hash or checksum algorithm used to validate shards, if
            any. Defaults to ``None``.
        keep_zip (bool): Whether the compressed form is kept or deleted when downloaded shards
            are decompressed. If ``False``, keep iff remote is local or no remote. Defaults to
            ``False``.
        epoch_size (int, optional): Number of samples drawn per epoch, balanced across all
            streams. When ``None``, the total number of underlying samples is used. Set this
            when weighting streams relatively to target a larger or smaller epoch size.
            Defaults to ``None``.
        predownload (int, optional): Target number of samples each worker downloads ahead of
            the current sample. Workers try to stay this many samples ahead during, but not
            before, training. A value greater than the per-device batch size is recommended so
            that at least one per-device batch of samples is cached locally. When ``None``, it
            is derived from the per-device batch size and the number of canonical nodes as
            ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. Defaults to ``None``.
        cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard
            cache. Before a shard is downloaded, the least recently used resident shard(s) may
            be evicted (deleted from the local cache) to stay under the limit. ``None``
            disables shard eviction. Defaults to ``None``.
        partition_algo (str): Partitioning algorithm to use. Defaults to ``orig``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
            resumption. The sample space is split evenly by the number of canonical nodes. A
            higher value gives the StreamingDataset replicas more independent, non-overlapping
            paths through the shards per model replica (increasing data source diversity).
            Defaults to ``None``, which is interpreted as 64 times the number of nodes of the
            initial run.

            .. note::

                For sequential sample ordering, set ``shuffle`` to ``False`` and
                ``num_canonical_nodes`` to the number of physical nodes of the initial run.
        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset
            is partitioned over the workers. Defaults to ``None``.
        shuffle (bool): Whether samples are iterated in randomized order. Defaults to
            ``False``.
        shuffle_algo (str): Shuffling algorithm to use. Defaults to ``py1s``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
        extra_local (str, optional): Base destination of extra local sample downloads.
            Defaults to ``None``.
        extra_remote (str, optional): Base source of extra remote sample downloads. Defaults
            to ``None``.
    """

    def __init__(
        self,
        *,
        remote: Optional[str] = None,
        local: Optional[str] = None,
        split: Optional[str] = None,
        download_retry: int = 2,
        download_timeout: float = 60,
        validate_hash: Optional[str] = None,
        keep_zip: bool = False,
        epoch_size: Optional[int] = None,
        predownload: Optional[int] = None,
        cache_limit: Optional[int] = None,
        partition_algo: str = 'orig',
        num_canonical_nodes: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        shuffle_algo: str = 'py1s',
        shuffle_seed: int = 9176,
        shuffle_block_size: int = 1 << 18,
        extra_local: Optional[str] = None,
        extra_remote: Optional[str] = None,
    ) -> None:
        # Everything except the two ``extra_*`` arguments is forwarded to the parent unchanged.
        super().__init__(
            remote=remote,
            local=local,
            split=split,
            download_retry=download_retry,
            download_timeout=download_timeout,
            validate_hash=validate_hash,
            keep_zip=keep_zip,
            epoch_size=epoch_size,
            predownload=predownload,
            cache_limit=cache_limit,
            partition_algo=partition_algo,
            num_canonical_nodes=num_canonical_nodes,
            batch_size=batch_size,
            shuffle=shuffle,
            shuffle_algo=shuffle_algo,
            shuffle_seed=shuffle_seed,
            shuffle_block_size=shuffle_block_size,
        )

        # The videos live outside of their shards, so remember where to fetch them from, where
        # to put them, and how long each of those fetches may take.
        self.download_timeout = download_timeout
        self.extra_local = extra_local
        self.extra_remote = extra_remote

    def get_item(self, idx: int) -> Any:
        """Fetch the sample stored at the given index, attaching its video bytes.

        When both ``extra_local`` and ``extra_remote`` are set, the file named by the sample's
        ``'content_path'`` field is downloaded unless it already exists locally, then read and
        stored in the sample under ``'content'``. Otherwise the sample is returned as-is.

        Args:
            idx (int): Index of the sample to fetch.

        Returns:
            Any: The sample.
        """
        sample = super().get_item(idx)

        if self.extra_local and self.extra_remote:
            content_path = sample['content_path']
            dest = os.path.join(self.extra_local, content_path)
            src = os.path.join(self.extra_remote, content_path)

            # The download thread normally fetches the video ahead of time; download here only
            # if the file is still missing.
            if not os.path.exists(dest):
                download_file(src, dest, self.download_timeout)

            with open(dest, 'rb') as stream:
                sample['content'] = stream.read()

        # Processing goes here.

        return sample

    def _download_thread(self, it: _Iterator) -> None:
        """Download shards and extra sample files in the background during iteration.

        The thread is started at the beginning of each epoch. It exits either when it runs out
        of samples or when a new epoch is started, which calls exit_threads() on its state (only
        one epoch is valid at a time).

        Every worker has its own download thread, which iterates ahead of the ready thread and
        yield loop.

        Args:
            it (_Iterator): State of __iter__.
        """
        # Download loop. It ends when a new epoch was started early (__iter__ was called again,
        # and there can only be one epoch at once), or when this epoch has no samples left.
        while not (it.should_exit() or it.prepare_index == it.total):
            # When limited to pre-downloading a given number of samples and we are already
            # that far (or further) ahead, wait and check again later.
            if self.predownload is not None:
                lead = it.prepare_index - it.yield_index
                if lead >= self.predownload:
                    sleep(TICK)
                    continue

            sample_id = it.sample_ids[it.prepare_index]

            # A sample ID of -1 is skipped: nothing is downloaded for it.
            if sample_id != -1:
                # Download and decompress the shard for this sample, if not already done.
                shard_id, _ = self.spanner[sample_id]
                self.prepare_shard(shard_id, False)

                # Predownload the sample's extra data.
                sample = super().get_item(sample_id)
                if self.extra_local and self.extra_remote:
                    content_path = sample['content_path']
                    dest = os.path.join(self.extra_local, content_path)
                    src = os.path.join(self.extra_remote, content_path)
                    if not os.path.exists(dest):
                        download_file(src, dest, self.download_timeout)

            # Step forward one sample.
            it.prepare_index += 1

        # Note that we exited.
        it.on_exit()