Skip to content

Commit

Permalink
support video mixin; support multi-output
Browse files Browse the repository at this point in the history
Signed-off-by: jie.hou <jie.hou@zilliz.com>
  • Loading branch information
jie.hou committed Mar 3, 2022
1 parent fa7a042 commit 89871c5
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 12 deletions.
22 changes: 14 additions & 8 deletions towhee/functional/data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from typing import Iterable, Iterator
from random import random, sample, shuffle
import concurrent.futures

from towhee.hparam import param_scope
from towhee.functional.option import Option, Some, Empty

from towhee.functional.mixins.data_source import DataSourceMixin
from towhee.functional.mixins.dispatcher import DispatcherMixin
from towhee.functional.mixins.parallel import ParallelMixin
from towhee.functional.mixins.computer_vision import ComputerVisionMixin


def _private_wrapper(func):
Expand All @@ -34,8 +34,8 @@ def wrapper(self, *arg, **kws):
return wrapper


class DataCollection(Iterable, DataSourceMixin, DispatcherMixin,
ParallelMixin):
class DataCollection(Iterable, DataSourceMixin, DispatcherMixin, ParallelMixin,
ComputerVisionMixin):
"""
DataCollection is a quick assambler for chained data processing operators.
Expand Down Expand Up @@ -324,7 +324,7 @@ def inner(data):
return inner(self._iterable)

@_private_wrapper
def map(self, unary_op):
def map(self, *arg):
"""
apply operator to data collection
Expand All @@ -335,10 +335,16 @@ def map(self, unary_op):
"""

# return map(unary_op, self._iterable)
if hasattr(self, '_executor') and isinstance(
self._executor, concurrent.futures.ThreadPoolExecutor):
# mmap
if len(arg) > 1:
return self.mmap(*arg)
unary_op = arg[0]

# pmap
if self.get_executor() is not None:
return self.pmap(unary_op, executor=self._executor)

#map
def inner(x):
if isinstance(x, Option):
return x.map(unary_op)
Expand All @@ -348,7 +354,7 @@ def inner(x):
return map(inner, self._iterable)

@_private_wrapper
def zip(self, other):
def zip(self, *others):
"""
combine two data collections
>>> dc1 = DataCollection([1,2,3,4])
Expand All @@ -357,7 +363,7 @@ def zip(self, other):
>>> list(dc3)
[(1, 2), (2, 3), (3, 4), (4, 5)]
"""
return zip(self, other)
return zip(self, *others)

@_private_wrapper
def filter(self, unary_op, drop_empty=False):
Expand Down
34 changes: 34 additions & 0 deletions towhee/functional/mixins/computer_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ComputerVisionMixin:
"""
Mixin for computer vision problems.
Examples:
>>> from towhee.functional import DataCollection
>>> DataCollection.from_camera(1).imshow()
"""

def imshow(self, title="image"):
import cv2 # pylint: disable=import-outside-toplevel
for im in self:
cv2.imshow(title, im)
cv2.waitKey(1)


if __name__ == '__main__': # pylint: disable=inconsistent-quotes
import doctest
doctest.testmod(verbose=False)
24 changes: 22 additions & 2 deletions towhee/functional/mixins/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from glob import glob
from towhee.utils.ndarray_utils import from_zip


class DataSourceMixin:
Expand All @@ -22,9 +21,30 @@ class DataSourceMixin:
"""

@classmethod
def glob(cls, pattern):
def from_glob(cls, pattern):
"""
generate a file list with `pattern`
"""
return cls.stream(glob(pattern))

@classmethod
def from_zip(cls, zip_path, pattern):
from towhee.utils.ndarray_utils import from_zip # pylint: disable=import-outside-toplevel
return cls.stream(from_zip(zip_path, pattern))

@classmethod
def from_camera(cls, device_id=0, limit=-1):
"""
read images from a camera.
"""
import cv2 # pylint: disable=import-outside-toplevel
cnt = limit
def inner():
nonlocal cnt
cap = cv2.VideoCapture(device_id)
while cnt != 0:
retval, im = cap.read()
if retval:
yield im
cnt -= 1
return cls.stream(inner())
87 changes: 85 additions & 2 deletions towhee/functional/mixins/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,21 @@ def set_parallel(self, num_worker=None, executor=None):
... .map(lambda x: stage_1_thread_set.add(threading.current_thread().ident))
... .map(lambda x: stage_2_thread_set.add(threading.current_thread().ident)).to_list()
... )
>>> len(stage_2_thread_set)
3
>>> len(stage_2_thread_set)>1
True
"""
if executor is not None:
self._executor = executor
if num_worker is not None:
self._executor = concurrent.futures.ThreadPoolExecutor(num_worker)
return self

def get_executor(self):
if hasattr(self, '_executor') and isinstance(
self._executor, concurrent.futures.ThreadPoolExecutor):
return self._executor
return None

def parallel(self, num_worker):
executor = concurrent.futures.ThreadPoolExecutor(num_worker)
queue = Queue(maxsize=num_worker)
Expand Down Expand Up @@ -147,6 +153,83 @@ def inner():

return self.factory(inner())

def mmap(self, *arg):
"""
apply multiple unary_op to data collection.
Examples:
1. using mmap
>>> from towhee.functional import DataCollection
>>> dc = DataCollection.range(5).stream()
>>> a, b = dc.mmap(lambda x: x+1, lambda x: x*2)
>>> c = a.map(lambda x: x+1)
>>> c.zip(b).to_list()
[(2, 0), (3, 2), (4, 4), (5, 6), (6, 8)]
2. using map instead of mmap
>>> from towhee.functional import DataCollection
>>> dc = DataCollection.range(5).stream()
>>> a, b, c = dc.map(lambda x: x+1, lambda x: x*2, lambda x: int(x/2))
>>> d = a.map(lambda x: x+1)
>>> d.zip(b, c).to_list()
[(2, 0, 0), (3, 2, 0), (4, 4, 1), (5, 6, 1), (6, 8, 2)]
3. dag execution
>>> dc = DataCollection.range(5).stream()
>>> a, b, c = dc.map(lambda x: x+1, lambda x: x*2, lambda x: int(x/2))
>>> d = a.map(lambda x: x+1)
>>> d.zip(b, c).map(lambda x: x[0]+x[1]+x[2]).to_list()
[2, 5, 9, 12, 16]
"""
executor = self.get_executor()
if executor is None:
executor = concurrent.futures.ThreadPoolExecutor(len(arg))
num_worker = 1
queues = [Queue(maxsize=num_worker) for _ in arg]
loop = asyncio.new_event_loop()
flag = True

def make_task(x, unary_op):

def task_wrapper():
if isinstance(x, Option):
return x.map(unary_op)
else:
return unary_op(x)

return task_wrapper

async def worker():
buffs = [[] for _ in arg]
for x in self:
for i in range(len(arg)):
queue = queues[i]
buff = buffs[i]
if len(buff) == num_worker:
queue.put(await buff.pop(0))
buff.append(
loop.run_in_executor(executor, make_task(x, arg[i])))
while sum([len(buff) for buff in buffs]) > 0:
for i in range(len(arg)):
queue = queues[i]
buff = buffs[i]
queue.put(await buff.pop(0))
nonlocal flag
flag = False

def worker_wrapper():
loop.run_until_complete(worker())

executor.submit(worker_wrapper)

def inner(queue):
nonlocal flag
while flag or not queue.empty():
yield queue.get()

retval = [inner(queue) for queue in queues]
return [self.factory(x) for x in retval]


if __name__ == '__main__': # pylint: disable=inconsistent-quotes
import doctest
Expand Down

0 comments on commit 89871c5

Please sign in to comment.