Skip to content

Commit

Permalink
Improve local version starting (#79)
Browse files Browse the repository at this point in the history
* Add verbose mode to local server start

* Quick fix

* Use threading

* Use select

* Try to fix

* Add progress counting

* Show a loading animation and add some warnings

* Run formatter

* Make it possible to restart after stopping
  • Loading branch information
hugoabonizio committed Apr 24, 2024
1 parent 0b3fbaf commit 589dcdf
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 17 deletions.
2 changes: 1 addition & 1 deletion maritalk/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"No libcublas.so found. cuBLAS v11 or v12 is required to run MariTalk. You can manually set the version using the `cuda_version` argument."
)

bin_folder = os.path.dirname(args.path)
bin_folder = os.path.dirname(os.path.expanduser(args.path))
if bin_folder:
os.makedirs(bin_folder, exist_ok=True)
download(args.license, args.path, dependencies)
123 changes: 108 additions & 15 deletions maritalk/resources/local.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import re
import csv
import sys
import time
import atexit
import threading
import subprocess
from tqdm import tqdm
from pathlib import Path
Expand All @@ -15,15 +17,15 @@ def check_gpu():
try:
result = subprocess.run(
[
'nvidia-smi',
'--query-gpu=name,compute_cap',
'--format=csv',
"nvidia-smi",
"--query-gpu=name,compute_cap",
"--format=csv",
],
capture_output=True,
text=True,
check=True,
)
reader = csv.reader(result.stdout.strip().split('\n'))
reader = csv.reader(result.stdout.strip().split("\n"))
headers = next(reader)
rows = list(reader)

Expand All @@ -50,19 +52,21 @@ def find_libs():

try:
output = subprocess.run(
['nvidia-smi'],
["nvidia-smi"],
stdout=subprocess.PIPE,
).stdout.decode('utf-8')
).stdout.decode("utf-8")
cuda_version_match = re.search(r"CUDA Version: (\d+\.\d+)", output)

if not cuda_version_match:
raise Exception("""Could not automatically detect the CUDA version. Verify the CUDA Toolkit installation or set the `cuda_version` parameter manually. For example:
raise Exception(
"""Could not automatically detect the CUDA version. Verify the CUDA Toolkit installation or set the `cuda_version` parameter manually. For example:
```
model.start_server("<YOUR LICENSE>", cuda_version="12.3")
```
To install the CUDA Toolkit, please refer to: https://developer.nvidia.com/cuda-downloads""")
To install the CUDA Toolkit, please refer to: https://developer.nvidia.com/cuda-downloads"""
)

versions["cuda_version"] = cuda_version_match.group(1)
except subprocess.CalledProcessError as e:
Expand All @@ -76,9 +80,7 @@ def find_libs():


def download(license: str, bin_path: str, dependencies: Dict[str, int]):
download_url = (
"https://functions.maritaca.ai/local/download"
)
download_url = "https://functions.maritaca.ai/local/download"
response = requests.post(
download_url,
json={
Expand Down Expand Up @@ -124,6 +126,38 @@ def download(license: str, bin_path: str, dependencies: Dict[str, int]):
raise Exception(f"Error downloading MariTalk binary: {e}")


def _get_total_mem():
    """
    Return the total system memory in gigabytes, or None when it is unknown.

    The value is read from the "total" column of the ``Mem:`` row printed by
    ``free -h`` and converted with ``_convert_to_gb``. None is returned when
    ``free`` is missing, cannot be executed, exits with an error, or prints
    no usable ``Mem:`` row.
    """
    try:
        output = subprocess.check_output(
            ["free", "-h"],
            text=True,
            # Force the C locale: the parsing below relies on the untranslated
            # "Mem:" label, which may be localized under the user's locale.
            env={**os.environ, "LC_ALL": "C"},
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError also covers FileNotFoundError (no `free` binary) and
        # PermissionError (binary present but not executable).
        return None

    for line in output.splitlines():
        if line.startswith("Mem:"):
            fields = line.split()
            # fields[0] is the "Mem:" label, fields[1] is the total column.
            return _convert_to_gb(fields[1]) if len(fields) > 1 else None
    return None


def _convert_to_gb(mem_str):
    """
    Convert a memory string from `free -h` (like 251Gi) to gigabytes as a float.

    Every binary unit the pattern accepts is converted (Ki, Mi, Gi, Ti, Pi,
    Ei), as well as plain bytes (``B``) and suffixes written without the
    trailing ``i`` (like 7.7G), which are treated as powers of 1024 too.
    A decimal comma is accepted in place of a decimal point.

    Returns None when the string does not start with a recognizable
    "<number><unit>" pair.
    """
    match = re.match(r"([0-9.,]+)\s*([KMGTPE]i?|B)", mem_str)
    if not match:
        return None

    value, unit = match.groups()
    # Power of 1024 separating each unit from a gigabyte, keyed by the
    # first letter of the unit.
    exponents = {"B": -3, "K": -2, "M": -1, "G": 0, "T": 1, "P": 2, "E": 3}
    try:
        amount = float(value.replace(",", "."))
    except ValueError:
        # Digits and separators in an order that is not a number ("1.2.3").
        return None
    return amount * 1024.0 ** exponents[unit[0]]


def _get_file_size(file_path):
    """
    Return the size of `file_path` in gigabytes (bytes / 1024**3) as a float.

    Any failure while reading the size (for example a missing file) is
    swallowed and reported as None.
    """
    bytes_per_gb = 1024**3
    try:
        return os.path.getsize(file_path) / bytes_per_gb
    except Exception:
        return None


def start_server(
license: str,
bin_path: str = "~/bin/maritalk",
Expand All @@ -150,11 +184,21 @@ def start_server(
os.makedirs(bin_folder, exist_ok=True)
download(license, bin_path, dependencies)

bin_size = _get_file_size(bin_path)

if bin_size:
min_memory = 30 if bin_size < 20 else 130
memory_available = _get_total_mem()
if memory_available and memory_available < min_memory:
print(
"WARNING: Verify that there is enough memory to load the model (at least 30 GB for the small version and 130 GB for the medium version)."
)

args = [bin_path, "--license", license, "--port", str(port)]
return subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stderr=subprocess.PIPE,
)


Expand All @@ -166,20 +210,36 @@ def __init__(self, host: str = "localhost", port: int = 9000):
"""@private"""
self.process = None
"""@private"""
self.loading = False
"""@private"""
self.loaded = False
"""@private"""

def start_server(
self,
license: str,
bin_path: str = "~/bin/maritalk",
cuda_version: Optional[int] = None,
verbose: str = True,
):
print(f"Starting MariTalk Local API at http://localhost:{self.port}")
if self.loaded:
return

self.loading = True
if verbose:
print(f"Starting MariTalk Local API at http://localhost:{self.port}/")
print(
"This process can take a few minutes (up to 10 minutes for the small version, depending on the hardware)."
)
loading_thread = threading.Thread(target=self._show_loading)
loading_thread.start()
self.process = start_server(license, bin_path, cuda_version, self.port)

while True:
try:
if self.process.poll() is not None:
output, _ = self.process.communicate()
output = output.decode('utf-8')
output = output.decode("utf-8")
raise Exception(
f"Failed to start process.\nOutput: {output}\nTry to run it manually: `{' '.join(self.process.args)}`"
)
Expand All @@ -189,18 +249,49 @@ def start_server(
except ConnectionError as ex:
time.sleep(1)

if verbose:
loading_thread.join()
print()

self.loading = False
self.loaded = True

def terminate():
print("Stopping MariTalk...")
self.stop_server()

atexit.register(terminate)

def _show_loading(self):
    """
    Draw a spinner with the elapsed time on stdout while `self.loading` is truthy.

    The line is redrawn in place (it starts with a carriage return) roughly
    every 0.1 seconds. The method returns once `self.loading` becomes falsy.
    """
    frames = ["⠋", "⠙", "⠚", "⠞", "⠖", "⠦", "⠴", "⠲", "⠳", "⠓"]
    began_at = time.time()
    frame = 0

    try:
        while self.loading:
            elapsed = int(time.time() - began_at)
            minutes, seconds = divmod(elapsed, 60)

            sys.stdout.write(
                f"\rLoading... {frames[frame]} ({minutes}min:{seconds}s)"
            )
            sys.stdout.flush()

            # Advance to the next frame, wrapping around at the end.
            frame = (frame + 1) % len(frames)
            time.sleep(0.1)
    except KeyboardInterrupt:
        sys.stdout.flush()

def stop_server(self):
    """
    Stop the server process attached to this client, if there is one.

    The process is asked to terminate and then waited for, so it is reaped
    and has actually exited before this method returns. If it has not
    exited after 10 seconds it is killed. When no process is attached, a
    message is printed and nothing else happens.
    """
    process = self.process
    if not process:
        print("No process attached to this client!")
        return

    # Detach first so the client state is consistent even if waiting fails.
    self.process = None
    self.loaded = False

    process.terminate()
    try:
        # Without a wait the child is never reaped and may still be running
        # (and holding its port) when the server is started again.
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()

def status(self):
response = requests.get(self.api_url)
Expand Down Expand Up @@ -326,4 +417,6 @@ def generate(
response.raise_for_status()

def generate_chat(self, *args, **kwargs):
raise Exception('This method was changed, please use `generate` for chat messages or `generate_raw` for raw few-shot examples instead.')
raise Exception(
"This method was changed, please use `generate` for chat messages or `generate_raw` for raw few-shot examples instead."
)
2 changes: 1 addition & 1 deletion maritalk/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@
if process.poll() is not None and output == b"":
break
if output:
print(output.decode().strip())
print(output.decode().strip(), end='')

0 comments on commit 589dcdf

Please sign in to comment.