Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
rainyfly committed Feb 19, 2024
1 parent 031b12f commit d1fcba5
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions llm/fastdeploy_llm/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,6 @@ def prepare_model(self):
self.model.start()

def execute(self, req_dict):
if self.model is None:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Model is not ready"
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
# 1. validate the deserializing process
task = Task()
try:
Expand All @@ -138,20 +130,19 @@ def execute(self, req_dict):
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
return task

# 3. check if exists task id conflict
if task.task_id is None:
task.task_id = str(uuid.uuid4())
request_start_time_dict[task.task_id] = request_start_time
if task.task_id in self.response_handler:
if task.task_id in event_dict:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Task id conflict with {}.".format(task.task_id)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
return None

# 4. validate the parameters in task
try:
Expand All @@ -163,7 +154,7 @@ def execute(self, req_dict):
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
return task

# 5. check if the requests queue is full
if self.model.requests_queue.qsize() > self.config.max_queue_num:
Expand All @@ -173,7 +164,7 @@ def execute(self, req_dict):
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
return task

# 6. check if the prefix embedding is exist
if self.config.is_ptuning and task.model_id is not None:
Expand All @@ -182,7 +173,7 @@ def execute(self, req_dict):
"task_prompt_embeddings.npy")
if not os.path.exists(np_file_path):
response_dict[req_dict['req_id']] = error_msg
return
return task

# 7. Add task to requests queue
task.call_back_func = stream_call_back
Expand All @@ -202,7 +193,7 @@ def execute(self, req_dict):
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
return task

except Exception as e:
error_type = ErrorType.Query
Expand All @@ -211,7 +202,7 @@ def execute(self, req_dict):
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
response_dict[req_dict['req_id']] = error_msg
return
return task

return task

Expand All @@ -231,14 +222,28 @@ async def inference(self, request_in: Request):
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
if self.model is None:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Model is not ready"
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
task = self.execute(input_dict)
if task is None: # task id conflict
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Task id conflict"
error_msg = error_format.format(error_type.name, error_code.name, error_info)
raise HTTPException(status_code=400, detail=error_msg)

event_dict[task.task_id] = asyncio.Event()
try:
await asyncio.wait_for(event_dict[task.task_id].wait(), self.wait_time_out)
except:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Timeout for getting inference result."
error_info = "Timeout for getting inference result, task={}".format(task)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
warning_logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
Expand Down Expand Up @@ -270,7 +275,8 @@ async def watch_result():
event_dict[task_id].set()

for task_id in response_checked_dict:
del response_dict[task_id]
if task_id in response_dict:
del response_dict[task_id]


model_dir = os.getenv("MODEL_DIR", None)
Expand Down

0 comments on commit d1fcba5

Please sign in to comment.