Skip to content

Commit

Permalink
allow for raw requests
Browse files Browse the repository at this point in the history
  • Loading branch information
sergix44 committed Feb 13, 2024
1 parent fc64add commit 5625cbe
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
23 changes: 5 additions & 18 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public function getConfig(): Config
return $this->config;
}

public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null): ?Output
public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null, bool $raw = false): Output|array|null
{
if ($apiName === null && $fnIndex === null) {
throw new InvalidArgumentException('You must provide an apiName or fnIndex');
Expand All @@ -75,21 +75,20 @@ public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex
throw new InvalidArgumentException('Endpoint not found');
}

return $this->submit($endpoint, $arguments);
return $this->submit($endpoint, $arguments, $raw);
}

public function submit(Endpoint $endpoint, array $arguments): ?Output
protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Output|array|null
{
$payload = $this->preparePayload($arguments);
$this->fireEvent(Event::SUBMIT, $payload);

if ($endpoint->skipsQueue()) {
return $this->http('post', $this->makeUri($endpoint), [
return $this->http('post', $endpoint->uri(), [
'data' => $payload,
'fn_index' => $endpoint->index,
'session_hash' => $this->sessionHash,
'event_data' => null,
], dto: Output::class);
], dto: $raw ? null : Output::class);
}

return match ($this->config->protocol) {
Expand Down Expand Up @@ -127,18 +126,6 @@ private function preparePayload(array $arguments): array
}, $arguments);
}

protected function makeUri(Endpoint $endpoint): string
{
$name = $endpoint->apiName();
if ($name !== null) {
$name = str_replace('/', '', $name);

return "run/$name";
}

return self::HTTP_PREDICT;
}

/**
* @throws GradioException
* @throws QueueFullException
Expand Down
12 changes: 12 additions & 0 deletions src/Client/Endpoint.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,16 @@ public function apiName(): ?string
{
return ! empty($this->data['api_name']) ? $this->data['api_name'] : null;
}

public function uri()
{
$name = $this->apiName();
if ($name !== null) {
$name = str_replace('/', '', $name);

return "run/$name";
}

return 'run/predict';
}
}
13 changes: 13 additions & 0 deletions tests/ExampleTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,16 @@

expect($client)->toBeInstanceOf(Client::class);
});

it('can test fnindexsudgugdhs', function () {
$client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/');

$client->predict([], fnIndex: 6, raw: true);
$client->predict(['hi'], fnIndex: 2, raw: true);
$client->predict([null, null], fnIndex: 3, raw: true);
$response = $client->predict([null, null, "", 0.9, 256, 0.6, 1.2], fnIndex: 4);

$value = $response->getOutput();

expect($value)->toBeArray();
});

0 comments on commit 5625cbe

Please sign in to comment.