-
Notifications
You must be signed in to change notification settings - Fork 453
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Other] Hub feature: add model hub management for fastdeploy (#453)
* add hub tool for fastdeploy * fix format * refactor code * remove EasyDict Co-authored-by: Jason <jiangjiajun@baidu.com>
- Loading branch information
1 parent
c11f4ba
commit 7e64f40
Showing
6 changed files
with
316 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License" | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import hashlib | ||
import os | ||
import time | ||
import json | ||
import uuid | ||
import yaml | ||
|
||
import fastdeploy.utils.hub_env as hubenv | ||
|
||
|
||
class HubConfig: | ||
''' | ||
FastDeploy model management configuration class. | ||
''' | ||
|
||
def __init__(self): | ||
self._initialize() | ||
self.file = os.path.join(hubenv.CONF_HOME, 'config.yaml') | ||
|
||
if not os.path.exists(self.file): | ||
self.flush() | ||
return | ||
|
||
with open(self.file, 'r') as file: | ||
try: | ||
cfg = yaml.load(file, Loader=yaml.FullLoader) | ||
self.data.update(cfg) | ||
except: | ||
... | ||
|
||
def _initialize(self): | ||
# Set default configuration values. | ||
self.data = {} | ||
self.data['server'] = 'http://paddlepaddle.org.cn/paddlehub' | ||
|
||
def reset(self): | ||
'''Reset configuration to default.''' | ||
self._initialize() | ||
self.flush() | ||
|
||
@property | ||
def server(self): | ||
'''Model server url.''' | ||
return self.data['server'] | ||
|
||
@server.setter | ||
def server(self, url: str): | ||
self.data['server'] = url | ||
self.flush() | ||
|
||
def flush(self): | ||
'''Flush the current configuration into the configuration file.''' | ||
with open(self.file, 'w') as file: | ||
cfg = json.loads(json.dumps(self.data)) | ||
yaml.dump(cfg, file) | ||
|
||
def __str__(self): | ||
cfg = json.loads(json.dumps(self.data)) | ||
return yaml.dump(cfg) | ||
|
||
|
||
config = HubConfig() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License" | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
''' | ||
This module is used to store environmental variables for fastdeploy model hub. | ||
FASTDEPLOY_HUB_HOME --> the root directory for storing fastdeploy model hub related data. Default to ~/.fastdeploy. Users can change the | ||
├ default value through the FASTDEPLOY_HUB_HOME environment variable. | ||
├── MODEL_HOME --> Store the downloaded fastdeploy models. | ||
├── CONF_HOME --> Store the default configuration files. | ||
''' | ||
|
||
import os | ||
|
||
|
||
def _get_user_home(): | ||
return os.path.expanduser('~') | ||
|
||
|
||
def _get_hub_home(): | ||
if 'FASTDEPLOY_HUB_HOME' in os.environ: | ||
home_path = os.environ['FASTDEPLOY_HUB_HOME'] | ||
if os.path.exists(home_path): | ||
if os.path.isdir(home_path): | ||
return home_path | ||
else: | ||
raise RuntimeError( | ||
'The environment variable FASTDEPLOY_HUB_HOME {} is not a directory.'. | ||
format(home_path)) | ||
else: | ||
return home_path | ||
return os.path.join(_get_user_home(), '.fastdeploy') | ||
|
||
|
||
def _get_sub_home(directory): | ||
home = os.path.join(_get_hub_home(), directory) | ||
os.makedirs(home, exist_ok=True) | ||
return home | ||
|
||
|
||
USER_HOME = _get_user_home() | ||
HUB_HOME = _get_hub_home() | ||
MODEL_HOME = _get_sub_home('models') | ||
CONF_HOME = _get_sub_home('conf') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License" | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import requests | ||
from typing import List | ||
|
||
from fastdeploy.utils.hub_config import config | ||
|
||
|
||
class ServerConnectionError(Exception): | ||
def __init__(self, url: str): | ||
self.url = url | ||
|
||
def __str__(self): | ||
tips = 'Can\'t connect to FastDeploy Model Server: {}'.format(self.url) | ||
return tips | ||
|
||
|
||
class ModelServer(object): | ||
''' | ||
FastDeploy server source | ||
Args: | ||
url(str) : Url of the server | ||
timeout(int) : Request timeout | ||
''' | ||
|
||
def __init__(self, url: str, timeout: int=10): | ||
self._url = url | ||
self._timeout = timeout | ||
|
||
def search_model(self, name: str, format: str=None, | ||
version: str=None) -> List[dict]: | ||
''' | ||
Search model from model server. | ||
Args: | ||
name(str) : FastDeploy model name | ||
format(str): FastDeploy model format | ||
version(str) : FastDeploy model version | ||
Return: | ||
result(list): search results | ||
''' | ||
params = {} | ||
params['name'] = name | ||
if format: | ||
params['format'] = format | ||
if version: | ||
params['version'] = version | ||
result = self.request(path='fastdeploy_search', params=params) | ||
if result['status'] == 0 and len(result['data']) > 0: | ||
return result['data'] | ||
return None | ||
|
||
def stat_model(self, name: str, format: str, version: str): | ||
''' | ||
Note a record when download a model for statistics. | ||
Args: | ||
name(str) : FastDeploy model name | ||
format(str): FastDeploy model format | ||
version(str) : FastDeploy model version | ||
Return: | ||
is_successful(bool): True if successful, False otherwise | ||
''' | ||
params = {} | ||
params['name'] = name | ||
params['format'] = format | ||
params['version'] = version | ||
params['from'] = 'fastdeploy' | ||
try: | ||
result = self.request(path='stat', params=params) | ||
except Exception: | ||
return False | ||
if result['status'] == 0: | ||
return True | ||
else: | ||
return False | ||
|
||
def request(self, path: str, params: dict) -> dict: | ||
'''Request server.''' | ||
api = '{}/{}'.format(self._url, path) | ||
try: | ||
result = requests.get(api, params, timeout=self._timeout) | ||
return result.json() | ||
except requests.exceptions.ConnectionError as e: | ||
raise ServerConnectionError(self._url) | ||
|
||
def is_connected(self): | ||
return self.check(self._url) | ||
|
||
@classmethod | ||
def check(cls, url: str) -> bool: | ||
''' | ||
Check if the specified url is a valid model server | ||
Args: | ||
url(str) : Url to check | ||
''' | ||
try: | ||
r = requests.get(url + '/search') | ||
return r.status_code == 200 | ||
except: | ||
return False | ||
|
||
|
||
model_server = ModelServer(config.server) |