From 0f7b0fc7077a022ff8d1f2f1c47ba87fee095d12 Mon Sep 17 00:00:00 2001 From: Sergio Brighenti Date: Tue, 13 Feb 2024 19:14:48 +0100 Subject: [PATCH] support for sse v1 v2 --- composer.json | 2 +- src/Client.php | 131 ++++++++++++++---- src/Client/Endpoint.php | 29 +++- src/Client/RemoteClient.php | 31 +++-- src/DTO/Config.php | 19 +++ src/DTO/Messages/Estimation.php | 18 +++ src/DTO/Messages/Log.php | 12 ++ src/DTO/Messages/Message.php | 31 +++++ .../ProcessCompleted.php | 2 +- .../ProcessGenerating.php | 2 +- .../{Websocket => Messages}/ProcessStarts.php | 2 +- src/DTO/{Websocket => Messages}/QueueFull.php | 2 +- src/DTO/{Websocket => Messages}/SendData.php | 2 +- src/DTO/{Websocket => Messages}/SendHash.php | 2 +- .../Resolvers}/MessageResolver.php | 21 +-- .../Resolvers}/MessageType.php | 4 +- src/DTO/Websocket/Estimation.php | 18 --- src/DTO/Websocket/Message.php | 12 -- tests/ExampleTest.php | 38 ++++- 19 files changed, 279 insertions(+), 99 deletions(-) create mode 100644 src/DTO/Messages/Estimation.php create mode 100644 src/DTO/Messages/Log.php create mode 100644 src/DTO/Messages/Message.php rename src/DTO/{Websocket => Messages}/ProcessCompleted.php (78%) rename src/DTO/{Websocket => Messages}/ProcessGenerating.php (79%) rename src/DTO/{Websocket => Messages}/ProcessStarts.php (53%) rename src/DTO/{Websocket => Messages}/QueueFull.php (51%) rename src/DTO/{Websocket => Messages}/SendData.php (51%) rename src/DTO/{Websocket => Messages}/SendHash.php (51%) rename src/{Websocket => DTO/Resolvers}/MessageResolver.php (59%) rename src/{Websocket => DTO/Resolvers}/MessageType.php (83%) delete mode 100644 src/DTO/Websocket/Estimation.php delete mode 100644 src/DTO/Websocket/Message.php diff --git a/composer.json b/composer.json index 3ea6953..9833f40 100644 --- a/composer.json +++ b/composer.json @@ -17,7 +17,7 @@ "require": { "php": "^8.2", "guzzlehttp/guzzle": "^7.7", - "nutgram/hydrator": ">=5.0", + "nutgram/hydrator": ">=6.0", "phrity/websocket": "^1.7.2", "ext-fileinfo": "*" }, diff --git a/src/Client.php b/src/Client.php index 896de05..0aee72c 100644 --- a/src/Client.php +++ b/src/Client.php @@ -6,15 +6,15 @@ use SergiX44\Gradio\Client\Endpoint; use SergiX44\Gradio\Client\RemoteClient; use SergiX44\Gradio\DTO\Config; +use SergiX44\Gradio\DTO\Messages\Estimation; +use SergiX44\Gradio\DTO\Messages\Message; +use SergiX44\Gradio\DTO\Messages\ProcessCompleted; +use SergiX44\Gradio\DTO\Messages\ProcessGenerating; +use SergiX44\Gradio\DTO\Messages\ProcessStarts; +use SergiX44\Gradio\DTO\Messages\QueueFull; +use SergiX44\Gradio\DTO\Messages\SendData; +use SergiX44\Gradio\DTO\Messages\SendHash; use SergiX44\Gradio\DTO\Output; -use SergiX44\Gradio\DTO\Websocket\Estimation; -use SergiX44\Gradio\DTO\Websocket\Message; -use SergiX44\Gradio\DTO\Websocket\ProcessCompleted; -use SergiX44\Gradio\DTO\Websocket\ProcessGenerating; -use SergiX44\Gradio\DTO\Websocket\ProcessStarts; -use SergiX44\Gradio\DTO\Websocket\QueueFull; -use SergiX44\Gradio\DTO\Websocket\SendData; -use SergiX44\Gradio\DTO\Websocket\SendHash; use SergiX44\Gradio\Event\Event; use SergiX44\Gradio\Exception\GradioException; use SergiX44\Gradio\Exception\QueueFullException; @@ -23,7 +23,9 @@ class Client extends RemoteClient { private const HTTP_PREDICT = 'run/predict'; - private const WS_PREDICT = 'queue/join'; + private const QUEUE_JOIN = 'queue/join'; + + private const SSE_GET_DATA = 'queue/data'; private const HTTP_CONFIG = 'config'; @@ -38,7 +40,7 @@ class Client extends RemoteClient public function __construct(string $src, string $hfToken = null, Config $config = null) { parent::__construct($src); - $this->config = $config ?? $this->get(self::HTTP_CONFIG, dto: Config::class); + $this->config = $config ?? $this->http('get', self::HTTP_CONFIG, dto: Config::class); $this->loadEndpoints($this->config->dependencies); $this->sessionHash = substr(md5(microtime()), 0, 11); $this->hfToken = $hfToken; @@ -46,18 +48,11 @@ public function __construct(string $src, string $hfToken = null, Config $config protected function loadEndpoints(array $dependencies): void { - foreach ($dependencies as $index => $dep) { - $endpoint = new Endpoint( - $this, - $index, - ! empty($dep['api_name']) ? $dep['api_name'] : null, - $dep['queue'] !== false, - count($dep['inputs']) - ); - + foreach ($dependencies as $index => $dp) { + $endpoint = new Endpoint($this->config, $index, $dp); $this->endpoints[$index] = $endpoint; - if ($endpoint->apiName !== null) { - $this->endpoints[$endpoint->apiName] = $endpoint; + if ($endpoint->apiName() !== null) { + $this->endpoints[$endpoint->apiName()] = $endpoint; } } } @@ -83,16 +78,24 @@ public function predict(array $arguments, string $apiName = null, int $fnIndex = return $this->submit($endpoint, $arguments); } - private function submit(Endpoint $endpoint, array $arguments): ?Output + public function submit(Endpoint $endpoint, array $arguments): ?Output { $payload = $this->preparePayload($arguments); $this->fireEvent(Event::SUBMIT, $payload); - if ($endpoint->useWebsockets) { - return $this->websocketLoop($endpoint, $payload); + if ($endpoint->skipsQueue()) { + return $this->http('post', $this->makeUri($endpoint), [ + 'data' => $payload, + 'fn_index' => $endpoint->index, + 'session_hash' => $this->sessionHash, + 'event_data' => null, + ], dto: Output::class); } - return $this->post(self::HTTP_PREDICT, ['data' => $payload], Output::class); + return match ($this->config->protocol) { + 'sse_v1', 'sse_v2' => $this->sseV1V2Loop($endpoint, $payload), + default => $this->websocketLoop($endpoint, $payload), + }; } private function preparePayload(array $arguments): array @@ -124,6 +127,17 @@ private function preparePayload(array $arguments): array }, $arguments); } + protected function makeUri(Endpoint $endpoint): string + { + $name = $endpoint->apiName(); + if ($name !== null) { + $name = str_replace('/', '', $name); + return "run/$name"; + } + + return self::HTTP_PREDICT; + } + /** * @throws GradioException * @throws QueueFullException @@ -131,9 +145,8 @@ private function preparePayload(array $arguments): array */ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output { - $ws = $this->ws(self::WS_PREDICT); + $ws = $this->ws(self::QUEUE_JOIN); - $message = null; while (true) { $data = $ws->receive(); @@ -183,4 +196,68 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output return $message?->output; } + + private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output + { + $response = $this->httpRaw('post', self::QUEUE_JOIN, [ + 'data' => $payload, + 'fn_index' => $endpoint->index, + 'session_hash' => $this->sessionHash, + ]); + + if ($response->getStatusCode() === 503) { + throw new QueueFullException(); + } + + if ($response->getStatusCode() !== 200) { + throw new GradioException('Error joining the queue'); + } + +// $data = $this->decodeResponse($response); +// $eventId = $data['event_id']; + + $response = $this->httpRaw('get', self::SSE_GET_DATA, ['session_hash' => $this->sessionHash], [ + 'headers' => [ + 'Accept' => 'text/event-stream', + ], + 'stream' => true, + ]); + + $buffer = ''; + $message = null; + while (!$response->getBody()->eof()) { + $data = $response->getBody()->read(1); + if ($data !== "\n") { + $buffer .= $data; + continue; + } + + // read second \n + $response->getBody()->read(1); + + // remove data: + $buffer = str_replace('data: ', '', $buffer); + $message = $this->hydrator->hydrateWithJson(Message::class, $buffer); + + if ($message instanceof ProcessCompleted) { + $this->fireEvent(Event::PROCESS_COMPLETED, [$message]); + if ($message->success) { + $this->fireEvent(Event::PROCESS_SUCCESS, [$message]); + } else { + $this->fireEvent(Event::PROCESS_FAILED, [$message]); + } + break; + } elseif ($message instanceof ProcessStarts) { + $this->fireEvent(Event::PROCESS_STARTS, [$message]); + } elseif ($message instanceof ProcessGenerating) { + $this->fireEvent(Event::PROCESS_GENERATING, [$message]); + } elseif ($message instanceof Estimation) { + $this->fireEvent(Event::QUEUE_ESTIMATION, [$message]); + } + + $buffer = ''; + } + + return $message?->output; + } } diff --git a/src/Client/Endpoint.php b/src/Client/Endpoint.php index 760d1c5..68e7f86 100644 --- a/src/Client/Endpoint.php +++ b/src/Client/Endpoint.php @@ -2,16 +2,35 @@ namespace SergiX44\Gradio\Client; -use SergiX44\Gradio\Client; +use SergiX44\Gradio\DTO\Config; readonly class Endpoint { + public function __construct( - public Client $client, + private Config $config, public int $index, - public ?string $apiName, - public bool $useWebsockets, - public int $argsCount = 1, + private readonly array $data ) { } + + public function __get(string $name): mixed + { + return $this->data[$name] ?? null; + } + + public function __isset(string $name): bool + { + return isset($this->data[$name]); + } + + public function skipsQueue(): bool + { + return !($this->data['queue'] ?? $this->config->enable_queue); + } + + public function apiName(): ?string + { + return !empty($this->data['api_name']) ? $this->data['api_name'] : null; + } } diff --git a/src/Client/RemoteClient.php b/src/Client/RemoteClient.php index 721918d..bbaa01a 100644 --- a/src/Client/RemoteClient.php +++ b/src/Client/RemoteClient.php @@ -20,10 +20,10 @@ abstract class RemoteClient extends RegisterEvents public function __construct(string $src) { if ( - ! str_starts_with($src, 'http://') && - ! str_starts_with($src, 'https://') && - ! str_starts_with($src, 'ws://') && - ! str_starts_with($src, 'wss://') + !str_starts_with($src, 'http://') && + !str_starts_with($src, 'https://') && + !str_starts_with($src, 'ws://') && + !str_starts_with($src, 'wss://') ) { throw new InvalidArgumentException('The src must not contain the protocol'); } @@ -36,22 +36,23 @@ public function __construct(string $src) 'base_uri' => str_replace('ws', 'http', $this->src), 'headers' => [ 'User-Agent' => 'gradio_client_php/1.0', + 'Accept' => 'application/json', ], ]); } - protected function get(string $uri, array $params = [], string $dto = null) + protected function http(string $method, string $uri, array $params = [], array $opt = [], ?string $dto = null) { - $response = $this->httpClient->get($uri, ['query' => $params]); - - return $this->parseResponse($response, $dto); + $response = $this->httpRaw($method, $uri, $params, $opt); + return $this->decodeResponse($response, $dto); } - protected function post(string $uri, array $params = [], string $dto = null) + protected function httpRaw(string $method, string $uri, array $params = [], array $opt = []) { - $response = $this->httpClient->post($uri, ['json' => $params]); - - return $this->parseResponse($response, $dto); + $keyContent = $method === 'get' ? 'query' : 'json'; + return $this->httpClient->request($method, $uri, array_merge([ + $keyContent => $params, + ], $opt)); } protected function ws(string $uri, array $options = []): EnhancedClient @@ -59,14 +60,14 @@ protected function ws(string $uri, array $options = []): EnhancedClient return new EnhancedClient(str_replace('http', 'ws', $this->src).$uri, $options); } - private function parseResponse(ResponseInterface $response, string $mapTo = null): mixed + protected function decodeResponse(ResponseInterface|string $response, string $mapTo = null): mixed { - $body = $response->getBody()->getContents(); + $body = $response instanceof ResponseInterface ? $response->getBody()->getContents() : $response; if ($mapTo !== null) { return $this->hydrator->hydrateWithJson($mapTo, $body); } - return json_decode($body, flags: JSON_THROW_ON_ERROR); + return json_decode($body, true, flags: JSON_THROW_ON_ERROR); } } diff --git a/src/DTO/Config.php b/src/DTO/Config.php index b2d0e33..bd73dd0 100644 --- a/src/DTO/Config.php +++ b/src/DTO/Config.php @@ -33,4 +33,23 @@ class Config public array $dependencies = []; public ?string $root = null; + + public ?string $protocol = null; + + private array $_extra = []; + + public function __set(string $name, $value): void + { + $this->_extra[$name] = $value; + } + + public function __get(string $name) + { + return $this->_extra[$name] ?? null; + } + + public function __isset(string $name): bool + { + return isset($this->_extra[$name]); + } } diff --git a/src/DTO/Messages/Estimation.php b/src/DTO/Messages/Estimation.php new file mode 100644 index 0000000..4eab53f --- /dev/null +++ b/src/DTO/Messages/Estimation.php @@ -0,0 +1,18 @@ +_extra[$name] = $value; + } + + public function __get(string $name) + { + return $this->_extra[$name] ?? null; + } + + public function __isset(string $name): bool + { + return isset($this->_extra[$name]); + } +} diff --git a/src/DTO/Websocket/ProcessCompleted.php b/src/DTO/Messages/ProcessCompleted.php similarity index 78% rename from src/DTO/Websocket/ProcessCompleted.php rename to src/DTO/Messages/ProcessCompleted.php index 13f2b55..07d6b54 100644 --- a/src/DTO/Websocket/ProcessCompleted.php +++ b/src/DTO/Messages/ProcessCompleted.php @@ -1,6 +1,6 @@ value => ProcessStarts::class, MessageType::PROCESS_GENERATING->value => ProcessGenerating::class, MessageType::PROCESS_COMPLETED->value => ProcessCompleted::class, - default => throw new InvalidArgumentException('Unknown msg type'), + MessageType::LOG->value => Log::class, + default => (new class extends Message {})::class, }; } } diff --git a/src/Websocket/MessageType.php b/src/DTO/Resolvers/MessageType.php similarity index 83% rename from src/Websocket/MessageType.php rename to src/DTO/Resolvers/MessageType.php index ab2d58e..6052549 100644 --- a/src/Websocket/MessageType.php +++ b/src/DTO/Resolvers/MessageType.php @@ -1,6 +1,6 @@ predict([ - // 'banana and lemon', '', 7.5, 25, 1234, - // ], fnIndex: 4); + $response = $client->predict([ + "house", // string in 'Prompt' Textbox component + "!", // string in 'Negative prompt' Textbox component + 0, // number (numeric value between 0 and 2147483647) in 'Seed' Slider component + 1024, // number (numeric value between 1024 and 1536) in 'Width' Slider component + 1024, // number (numeric value between 1024 and 1536) in 'Height' Slider component + 10, // number (numeric value between 10 and 30) in 'Prior Inference Steps' Slider component + 0, // number (numeric value between 0 and 20) in 'Prior Guidance Scale' Slider component + 4, // number (numeric value between 4 and 12) in 'Decoder Inference Steps' Slider component + 0, // number (numeric value between 0 and 0) in 'Decoder Guidance Scale' Slider component + 1, // number (numeric value between 1 and 2) in 'Number of Images' Slider component + ], '/run'); - expect($c)->toBeInstanceOf(Client::class); + $outputs = $response->getOutputs(); + + expect($client)->toBeInstanceOf(Client::class); +}); + +it('can test another model', function () { + $client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/'); + + $response = $client->predict([ + 'list all names of the week in all languages', // str in 'parameter_28' Textbox component + '', // str in 'Optional system prompt' Textbox component + 0.9, // float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component + 4096, // float (numeric value between 0 and 4096) in 'Max new tokens' Slider component + 0.6, // float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component + 1.2, // float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component + ], '/chat'); + + $outputs = $response->getOutputs(); + + expect($client)->toBeInstanceOf(Client::class); });