Skip to content

Commit

Permalink
Improve sse protocol support
Browse files Browse the repository at this point in the history
  • Loading branch information
sergix44 committed Mar 23, 2024
1 parent c4118c0 commit 8262fd3
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 36 deletions.
73 changes: 52 additions & 21 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Client extends RemoteClient

private const QUEUE_JOIN = 'queue/join';

private const SSE_GET_DATA = 'queue/data';
private const SSE_QUEUE_DATA = 'queue/data';

private const HTTP_CONFIG = 'config';

Expand Down Expand Up @@ -62,7 +62,7 @@ public function getConfig(): Config
return $this->config;
}

public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null, bool $raw = false): Output|array|null
public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null, bool $raw = false, ?int $triggerId = null): Output|array|null
{
if ($apiName === null && $fnIndex === null) {
throw new InvalidArgumentException('You must provide an apiName or fnIndex');
Expand All @@ -75,10 +75,10 @@ public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex
throw new InvalidArgumentException('Endpoint not found');
}

return $this->submit($endpoint, $arguments, $raw);
return $this->submit($endpoint, $arguments, $raw, $triggerId);
}

protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Output|array|null
protected function submit(Endpoint $endpoint, array $arguments, bool $raw, ?int $triggerId = null): Output|array|null
{
$payload = $this->preparePayload($arguments);
$this->fireEvent(Event::SUBMIT, $payload);
Expand All @@ -88,12 +88,15 @@ protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Outp
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
'trigger_id' => $triggerId,
'event_data' => null,
], dto: $raw ? null : Output::class);
}

return match ($this->config->protocol) {
'sse_v1', 'sse_v2' => $this->sseV1V2Loop($endpoint, $payload),
default => $this->websocketLoop($endpoint, $payload),
'sse', 'sse_v1', 'sse_v2', 'sse_v2.1', 'sse_v3' => $this->sseLoop($endpoint, $payload, $this->config->protocol, $triggerId),
'ws' => $this->websocketLoop($endpoint, $payload),
default => throw new GradioException('Unknown protocol '.$this->config->protocol),
};
}

Expand Down Expand Up @@ -185,26 +188,34 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
return $message?->output;
}

private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
private function sseLoop(Endpoint $endpoint, array $payload, string $protocol, ?int $triggerId): ?Output
{
$response = $this->httpRaw('post', self::QUEUE_JOIN, [
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
]);
if ($protocol === 'sse') {
$getEndpoint = self::QUEUE_JOIN;
} else {
$getEndpoint = self::SSE_QUEUE_DATA;
$response = $this->httpRaw('post', self::QUEUE_JOIN, [
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
]);

if ($response->getStatusCode() === 503) {
throw new QueueFullException();
}

if ($response->getStatusCode() !== 200) {
throw new GradioException('Error joining the queue');
if ($response->getStatusCode() === 503) {
throw new QueueFullException();
}

if ($response->getStatusCode() !== 200) {
throw new GradioException('Error joining the queue');
}
}

// $data = $this->decodeResponse($response);
// $eventId = $data['event_id'];
$params = ['session_hash' => $this->sessionHash];
if ($protocol === 'sse') {
$params['fn_index'] = $endpoint->index;
}

$response = $this->httpRaw('get', self::SSE_GET_DATA, ['session_hash' => $this->sessionHash], [
$response = $this->httpRaw('get', $getEndpoint, $params, [
'headers' => [
'Accept' => 'text/event-stream',
],
Expand All @@ -213,7 +224,7 @@ private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output

$buffer = '';
$message = null;
while (! $response->getBody()->eof()) {
while (!$response->getBody()->eof()) {
$data = $response->getBody()->read(1);
if ($data !== "\n") {
$buffer .= $data;
Expand All @@ -228,7 +239,27 @@ private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
$buffer = str_replace('data: ', '', $buffer);
$message = $this->hydrator->hydrateWithJson(Message::class, $buffer);

if ($message instanceof SendData && $protocol === 'sse') {
$sendData = $this->httpRaw('post', self::SSE_QUEUE_DATA, [
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
'event_id' => $message->event_id,
'event_data' => $message?->event_data,
'trigger_id' => $triggerId,
]);
if ($sendData->getStatusCode() !== 200) {
throw new GradioException('Error sending data');
}
$buffer = '';
continue;
}

if ($message instanceof ProcessCompleted) {
if (in_array($protocol, ['sse_v2', 'sse_v2.1'], true)) {
$response->getBody()->close();
}

$this->fireEvent(Event::PROCESS_COMPLETED, [$message]);
if ($message->success) {
$this->fireEvent(Event::PROCESS_SUCCESS, [$message]);
Expand Down
1 change: 1 addition & 0 deletions src/DTO/Messages/SendData.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@

class SendData extends Message
{
public ?string $event_id = null;
}
17 changes: 17 additions & 0 deletions src/DTO/Output.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ class Output

public array $data = [];

private array $_extra = [];

public function __set(string $name, $value): void
{
$this->_extra[$name] = $value;
}

public function __get(string $name)
{
return $this->_extra[$name] ?? null;
}

public function __isset(string $name): bool
{
return isset($this->_extra[$name]);
}

public function getOutputs(): array
{
return $this->data ?? [];
Expand Down
23 changes: 8 additions & 15 deletions tests/ExampleTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,22 @@
});

it('can test another model', function () {
$client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/');
$client = new Client('https://ehristoforu-mixtral-46-7b-chat.hf.space');

$response = $client->predict([
'list all names of the week in all languages', // str in 'parameter_28' Textbox component
'', // str in 'Optional system prompt' Textbox component
0.9, // float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
4096, // float (numeric value between 0 and 4096) in 'Max new tokens' Slider component
0.6, // float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
1.2, // float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
], '/chat');
$client->predict([], fnIndex: 5, raw: true);
$client->predict(['hi'], fnIndex: 1, raw: true);
$client->predict([null, []], fnIndex: 2, raw: true);
$response = $client->predict([null, null, "", 0.9, 256, 0.9, 1.2], fnIndex: 3);
$client->predict([], fnIndex: 6, raw: true);

$outputs = $response->getOutputs();

expect($client)->toBeInstanceOf(Client::class);
});

it('can test fnindexsudgugdhs', function () {
$client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/');

$client->predict([], fnIndex: 6, raw: true);
$client->predict(['hi'], fnIndex: 2, raw: true);
$client->predict([null, null], fnIndex: 3, raw: true);
$response = $client->predict([null, null, '', 0.9, 256, 0.6, 1.2], fnIndex: 4);
$client = new Client('https://deepseek-ai-deepseek-vl-7b.hf.space');
$response = $client->predict([[["Hello!", null]], 0, 0, 0, 0, 0, 'DeepSeek-VL 7B'], apiName: '/predict');

$value = $response->getOutput();

Expand Down

0 comments on commit 8262fd3

Please sign in to comment.