Skip to content

Commit

Permalink
support for sse v1 v2
Browse files Browse the repository at this point in the history
  • Loading branch information
sergix44 committed Feb 13, 2024
1 parent b232f51 commit 0f7b0fc
Show file tree
Hide file tree
Showing 19 changed files with 279 additions and 99 deletions.
2 changes: 1 addition & 1 deletion composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"require": {
"php": "^8.2",
"guzzlehttp/guzzle": "^7.7",
"nutgram/hydrator": ">=5.0",
"nutgram/hydrator": ">=6.0",
"phrity/websocket": "^1.7.2",
"ext-fileinfo": "*"
},
Expand Down
131 changes: 104 additions & 27 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
use SergiX44\Gradio\Client\Endpoint;
use SergiX44\Gradio\Client\RemoteClient;
use SergiX44\Gradio\DTO\Config;
use SergiX44\Gradio\DTO\Messages\Estimation;
use SergiX44\Gradio\DTO\Messages\Message;
use SergiX44\Gradio\DTO\Messages\ProcessCompleted;
use SergiX44\Gradio\DTO\Messages\ProcessGenerating;
use SergiX44\Gradio\DTO\Messages\ProcessStarts;
use SergiX44\Gradio\DTO\Messages\QueueFull;
use SergiX44\Gradio\DTO\Messages\SendData;
use SergiX44\Gradio\DTO\Messages\SendHash;
use SergiX44\Gradio\DTO\Output;
use SergiX44\Gradio\DTO\Websocket\Estimation;
use SergiX44\Gradio\DTO\Websocket\Message;
use SergiX44\Gradio\DTO\Websocket\ProcessCompleted;
use SergiX44\Gradio\DTO\Websocket\ProcessGenerating;
use SergiX44\Gradio\DTO\Websocket\ProcessStarts;
use SergiX44\Gradio\DTO\Websocket\QueueFull;
use SergiX44\Gradio\DTO\Websocket\SendData;
use SergiX44\Gradio\DTO\Websocket\SendHash;
use SergiX44\Gradio\Event\Event;
use SergiX44\Gradio\Exception\GradioException;
use SergiX44\Gradio\Exception\QueueFullException;
Expand All @@ -23,7 +23,9 @@ class Client extends RemoteClient
{
private const HTTP_PREDICT = 'run/predict';

private const WS_PREDICT = 'queue/join';
private const QUEUE_JOIN = 'queue/join';

private const SSE_GET_DATA = 'queue/data';

private const HTTP_CONFIG = 'config';

Expand All @@ -38,26 +40,19 @@ class Client extends RemoteClient
public function __construct(string $src, string $hfToken = null, Config $config = null)
{
parent::__construct($src);
$this->config = $config ?? $this->get(self::HTTP_CONFIG, dto: Config::class);
$this->config = $config ?? $this->http('get', self::HTTP_CONFIG, dto: Config::class);
$this->loadEndpoints($this->config->dependencies);
$this->sessionHash = substr(md5(microtime()), 0, 11);
$this->hfToken = $hfToken;
}

protected function loadEndpoints(array $dependencies): void
{
foreach ($dependencies as $index => $dep) {
$endpoint = new Endpoint(
$this,
$index,
! empty($dep['api_name']) ? $dep['api_name'] : null,
$dep['queue'] !== false,
count($dep['inputs'])
);

foreach ($dependencies as $index => $dp) {
$endpoint = new Endpoint($this->config, $index, $dp);
$this->endpoints[$index] = $endpoint;
if ($endpoint->apiName !== null) {
$this->endpoints[$endpoint->apiName] = $endpoint;
if ($endpoint->apiName() !== null) {
$this->endpoints[$endpoint->apiName()] = $endpoint;
}
}
}
Expand All @@ -83,16 +78,24 @@ public function predict(array $arguments, string $apiName = null, int $fnIndex =
return $this->submit($endpoint, $arguments);
}

private function submit(Endpoint $endpoint, array $arguments): ?Output
public function submit(Endpoint $endpoint, array $arguments): ?Output
{
$payload = $this->preparePayload($arguments);
$this->fireEvent(Event::SUBMIT, $payload);

if ($endpoint->useWebsockets) {
return $this->websocketLoop($endpoint, $payload);
if ($endpoint->skipsQueue()) {
return $this->http('post', $this->makeUri($endpoint), [
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
'event_data' => null,
], dto: Output::class);
}

return $this->post(self::HTTP_PREDICT, ['data' => $payload], Output::class);
return match ($this->config->protocol) {
'sse_v1', 'sse_v2' => $this->sseV1V2Loop($endpoint, $payload),
default => $this->websocketLoop($endpoint, $payload),
};
}

private function preparePayload(array $arguments): array
Expand Down Expand Up @@ -124,16 +127,26 @@ private function preparePayload(array $arguments): array
}, $arguments);
}

protected function makeUri(Endpoint $endpoint): string
{
$name = $endpoint->apiName();
if ($name !== null) {
$name = str_replace('/', '', $name);
return "run/$name";
}

return self::HTTP_PREDICT;
}

/**
* @throws GradioException
* @throws QueueFullException
* @throws \JsonException
*/
private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
{
$ws = $this->ws(self::WS_PREDICT);
$ws = $this->ws(self::QUEUE_JOIN);

$message = null;
while (true) {
$data = $ws->receive();

Expand Down Expand Up @@ -183,4 +196,68 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output

return $message?->output;
}

private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
{
$response = $this->httpRaw('post', self::QUEUE_JOIN, [
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
]);

if ($response->getStatusCode() === 503) {
throw new QueueFullException();
}

if ($response->getStatusCode() !== 200) {
throw new GradioException('Error joining the queue');
}

// $data = $this->decodeResponse($response);
// $eventId = $data['event_id'];

$response = $this->httpRaw('get', self::SSE_GET_DATA, ['session_hash' => $this->sessionHash], [
'headers' => [
'Accept' => 'text/event-stream',
],
'stream' => true,
]);

$buffer = '';
$message = null;
while (!$response->getBody()->eof()) {
$data = $response->getBody()->read(1);
if ($data !== "\n") {
$buffer .= $data;
continue;
}

// read second \n
$response->getBody()->read(1);

// remove data:
$buffer = str_replace('data: ', '', $buffer);
$message = $this->hydrator->hydrateWithJson(Message::class, $buffer);

if ($message instanceof ProcessCompleted) {
$this->fireEvent(Event::PROCESS_COMPLETED, [$message]);
if ($message->success) {
$this->fireEvent(Event::PROCESS_SUCCESS, [$message]);
} else {
$this->fireEvent(Event::PROCESS_FAILED, [$message]);
}
break;
} elseif ($message instanceof ProcessStarts) {
$this->fireEvent(Event::PROCESS_STARTS, [$message]);
} elseif ($message instanceof ProcessGenerating) {
$this->fireEvent(Event::PROCESS_GENERATING, [$message]);
} elseif ($message instanceof Estimation) {
$this->fireEvent(Event::QUEUE_ESTIMATION, [$message]);
}

$buffer = '';
}

return $message?->output;
}
}
29 changes: 24 additions & 5 deletions src/Client/Endpoint.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,35 @@

namespace SergiX44\Gradio\Client;

use SergiX44\Gradio\Client;
use SergiX44\Gradio\DTO\Config;

readonly class Endpoint
{

public function __construct(
public Client $client,
private Config $config,
public int $index,
public ?string $apiName,
public bool $useWebsockets,
public int $argsCount = 1,
private readonly array $data
) {
}

public function __get(string $name): mixed
{
return $this->data[$name] ?? null;
}

public function __isset(string $name): bool
{
return isset($this->data[$name]);
}

public function skipsQueue(): bool
{
return !($this->data['queue'] ?? $this->config->enable_queue);
}

public function apiName(): ?string
{
return !empty($this->data['api_name']) ? $this->data['api_name'] : null;
}
}
31 changes: 16 additions & 15 deletions src/Client/RemoteClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ abstract class RemoteClient extends RegisterEvents
public function __construct(string $src)
{
if (
! str_starts_with($src, 'http://') &&
! str_starts_with($src, 'https://') &&
! str_starts_with($src, 'ws://') &&
! str_starts_with($src, 'wss://')
!str_starts_with($src, 'http://') &&
!str_starts_with($src, 'https://') &&
!str_starts_with($src, 'ws://') &&
!str_starts_with($src, 'wss://')
) {
throw new InvalidArgumentException('The src must not contain the protocol');
}
Expand All @@ -36,37 +36,38 @@ public function __construct(string $src)
'base_uri' => str_replace('ws', 'http', $this->src),
'headers' => [
'User-Agent' => 'gradio_client_php/1.0',
'Accept' => 'application/json',
],
]);
}

protected function get(string $uri, array $params = [], string $dto = null)
protected function http(string $method, string $uri, array $params = [], array $opt = [], ?string $dto = null)
{
$response = $this->httpClient->get($uri, ['query' => $params]);

return $this->parseResponse($response, $dto);
$response = $this->httpRaw($method, $uri, $params, $opt);
return $this->decodeResponse($response, $dto);
}

protected function post(string $uri, array $params = [], string $dto = null)
protected function httpRaw(string $method, string $uri, array $params = [], array $opt = [])
{
$response = $this->httpClient->post($uri, ['json' => $params]);

return $this->parseResponse($response, $dto);
$keyContent = $method === 'get' ? 'query' : 'json';
return $this->httpClient->request($method, $uri, array_merge([
$keyContent => $params,
], $opt));
}

protected function ws(string $uri, array $options = []): EnhancedClient
{
return new EnhancedClient(str_replace('http', 'ws', $this->src).$uri, $options);
}

private function parseResponse(ResponseInterface $response, string $mapTo = null): mixed
protected function decodeResponse(ResponseInterface|string $response, string $mapTo = null): mixed
{
$body = $response->getBody()->getContents();
$body = $response instanceof ResponseInterface ? $response->getBody()->getContents() : $response;

if ($mapTo !== null) {
return $this->hydrator->hydrateWithJson($mapTo, $body);
}

return json_decode($body, flags: JSON_THROW_ON_ERROR);
return json_decode($body, true, flags: JSON_THROW_ON_ERROR);
}
}
19 changes: 19 additions & 0 deletions src/DTO/Config.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,23 @@ class Config
public array $dependencies = [];

public ?string $root = null;

public ?string $protocol = null;

private array $_extra = [];

public function __set(string $name, $value): void
{
$this->_extra[$name] = $value;
}

public function __get(string $name)
{
return $this->_extra[$name] ?? null;
}

public function __isset(string $name): bool
{
return isset($this->_extra[$name]);
}
}
18 changes: 18 additions & 0 deletions src/DTO/Messages/Estimation.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?php

namespace SergiX44\Gradio\DTO\Messages;

class Estimation extends Message
{
public ?int $rank = null;

public ?int $queue_size = null;

public ?float $avg_event_process_time = null;

public ?float $avg_event_concurrent_process_time = null;

public ?float $rank_eta = null;

public ?float $queue_eta = null;
}
12 changes: 12 additions & 0 deletions src/DTO/Messages/Log.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<?php

namespace SergiX44\Gradio\DTO\Messages;

class Log extends Message
{
public ?string $log = null;

public ?string $level = null;

public ?string $event_id = null;
}
31 changes: 31 additions & 0 deletions src/DTO/Messages/Message.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<?php

namespace SergiX44\Gradio\DTO\Messages;

use SergiX44\Gradio\DTO\Resolvers\MessageResolver;
use SergiX44\Gradio\DTO\Resolvers\MessageType;
use SergiX44\Hydrator\Resolver\EnumOrScalar;

#[MessageResolver]
abstract class Message
{
#[EnumOrScalar]
public MessageType|string $msg;

private array $_extra = [];

public function __set(string $name, $value): void
{
$this->_extra[$name] = $value;
}

public function __get(string $name)
{
return $this->_extra[$name] ?? null;
}

public function __isset(string $name): bool
{
return isset($this->_extra[$name]);
}
}
Loading

0 comments on commit 0f7b0fc

Please sign in to comment.