Skip to content

Commit

Permalink
feat: add support for maximum concurrency of /api/v1/videos
Browse files Browse the repository at this point in the history
  • Loading branch information
kevin.zhang committed Apr 16, 2024
1 parent 414bcb0 commit abe12ab
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 2 deletions.
57 changes: 57 additions & 0 deletions app/controllers/manager/base_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import threading
from typing import Callable, Any, Dict


class TaskManager:
def __init__(self, max_concurrent_tasks: int):
self.max_concurrent_tasks = max_concurrent_tasks
self.current_tasks = 0
self.lock = threading.Lock()
self.queue = self.create_queue()

def create_queue(self):
raise NotImplementedError()

def add_task(self, func: Callable, *args: Any, **kwargs: Any):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks:
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
self.execute_task(func, *args, **kwargs)
else:
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
self.enqueue({"func": func, "args": args, "kwargs": kwargs})

def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
thread.start()

def run_task(self, func: Callable, *args: Any, **kwargs: Any):
try:
with self.lock:
self.current_tasks += 1
func(*args, **kwargs) # 在这里调用函数,传递*args和**kwargs
finally:
self.task_done()

def check_queue(self):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
task_info = self.dequeue()
func = task_info['func']
args = task_info.get('args', ())
kwargs = task_info.get('kwargs', {})
self.execute_task(func, *args, **kwargs)

def task_done(self):
with self.lock:
self.current_tasks -= 1
self.check_queue()

def enqueue(self, task: Dict):
raise NotImplementedError()

def dequeue(self):
raise NotImplementedError()

def is_queue_empty(self):
raise NotImplementedError()
18 changes: 18 additions & 0 deletions app/controllers/manager/memory_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from queue import Queue
from typing import Dict

from app.controllers.manager.base_manager import TaskManager


class InMemoryTaskManager(TaskManager):
def create_queue(self):
return Queue()

def enqueue(self, task: Dict):
self.queue.put(task)

def dequeue(self):
return self.queue.get()

def is_queue_empty(self):
return self.queue.empty()
48 changes: 48 additions & 0 deletions app/controllers/manager/redis_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
from typing import Dict

import redis

from app.controllers.manager.base_manager import TaskManager
from app.models.schema import VideoParams
from app.services import task as tm

FUNC_MAP = {
'start': tm.start,
# 'start_test': tm.start_test
}


class RedisTaskManager(TaskManager):
def __init__(self, max_concurrent_tasks: int, redis_url: str):
self.redis_client = redis.Redis.from_url(redis_url)
super().__init__(max_concurrent_tasks)

def create_queue(self):
return "task_queue"

def enqueue(self, task: Dict):
task_with_serializable_params = task.copy()

if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()

# 将函数对象转换为其名称
task_with_serializable_params['func'] = task['func'].__name__
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))

def dequeue(self):
task_json = self.redis_client.lpop(self.queue)
if task_json:
task_info = json.loads(task_json)
# 将函数名称转换回函数对象
task_info['func'] = FUNC_MAP[task_info['func']]

if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])

return task_info
return None

def is_queue_empty(self):
return self.redis_client.llen(self.queue) == 0
34 changes: 33 additions & 1 deletion app/controllers/v1/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from app.config import config
from app.controllers import base
from app.controllers.manager.memory_manager import InMemoryTaskManager
from app.controllers.manager.redis_manager import RedisTaskManager
from app.controllers.v1.base import new_router
from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \
Expand All @@ -22,6 +24,35 @@
# router = new_router(dependencies=[Depends(base.verify_token)])
router = new_router()

_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)

redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
# 根据配置选择合适的任务管理器
if _enable_redis:
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url)
else:
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)

# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
# async def create_video_test(request: Request, body: TaskVideoRequest):
# task_id = utils.get_uuid()
# request_id = base.get_task_id(request)
# try:
# task = {
# "task_id": task_id,
# "request_id": request_id,
# "params": body.dict(),
# }
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
# return utils.get_response(200, task)
# except ValueError as e:
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")


@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
Expand All @@ -34,7 +65,8 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
"params": body.dict(),
}
sm.state.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body)
# background_tasks.add_task(tm.start, task_id=task_id, params=body)
task_manager.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task)
except ValueError as e:
Expand Down
2 changes: 1 addition & 1 deletion app/models/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MaterialInfo:
# ]


class VideoParams:
class VideoParams(BaseModel):
"""
{
"video_subject": "",
Expand Down
6 changes: 6 additions & 0 deletions app/services/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,9 @@ def start(task_id, params: VideoParams):
}
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs


# def start_test(task_id, params: VideoParams):
# print(f"start task {task_id} \n")
# time.sleep(5)
# print(f"task {task_id} finished \n")
7 changes: 7 additions & 0 deletions config.example.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
listen_host = "127.0.0.1"
listen_port = 8502

[app]
# Pexels API Key
# Register at https://www.pexels.com/api/ to get your API key.
Expand Down Expand Up @@ -134,6 +137,10 @@
redis_host = "localhost"
redis_port = 6379
redis_db = 0
redis_password = ""

# 文生视频时的最大并发任务数
max_concurrent_tasks = 5

[whisper]
# Only effective when subtitle_provider is "whisper"
Expand Down

0 comments on commit abe12ab

Please sign in to comment.