From 5625cbe2d1e4b873704ca14d53c0d1e5fa51a21c Mon Sep 17 00:00:00 2001 From: Sergio Brighenti Date: Tue, 13 Feb 2024 20:00:12 +0100 Subject: [PATCH] allow for raw requests --- src/Client.php | 23 +++++------------------ src/Client/Endpoint.php | 12 ++++++++++++ tests/ExampleTest.php | 13 +++++++++++++ 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/Client.php b/src/Client.php index 3ead71b..834980c 100644 --- a/src/Client.php +++ b/src/Client.php @@ -62,7 +62,7 @@ public function getConfig(): Config return $this->config; } - public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null): ?Output + public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null, bool $raw = false): Output|array|null { if ($apiName === null && $fnIndex === null) { throw new InvalidArgumentException('You must provide an apiName or fnIndex'); @@ -75,21 +75,20 @@ public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex throw new InvalidArgumentException('Endpoint not found'); } - return $this->submit($endpoint, $arguments); + return $this->submit($endpoint, $arguments, $raw); } - public function submit(Endpoint $endpoint, array $arguments): ?Output + protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Output|array|null { $payload = $this->preparePayload($arguments); $this->fireEvent(Event::SUBMIT, $payload); if ($endpoint->skipsQueue()) { - return $this->http('post', $this->makeUri($endpoint), [ + return $this->http('post', $endpoint->uri(), [ 'data' => $payload, 'fn_index' => $endpoint->index, 'session_hash' => $this->sessionHash, - 'event_data' => null, - ], dto: Output::class); + ], dto: $raw ? null : Output::class); } return match ($this->config->protocol) { @@ -127,18 +126,6 @@ private function preparePayload(array $arguments): array }, $arguments); } - protected function makeUri(Endpoint $endpoint): string - { - $name = $endpoint->apiName(); - if ($name !== null) { - $name = str_replace('/', '', $name); - - return "run/$name"; - } - - return self::HTTP_PREDICT; - } - /** * @throws GradioException * @throws QueueFullException diff --git a/src/Client/Endpoint.php b/src/Client/Endpoint.php index 646058c..3678eab 100644 --- a/src/Client/Endpoint.php +++ b/src/Client/Endpoint.php @@ -32,4 +32,16 @@ public function apiName(): ?string { return ! empty($this->data['api_name']) ? $this->data['api_name'] : null; } + + public function uri() + { + $name = $this->apiName(); + if ($name !== null) { + $name = str_replace('/', '', $name); + + return "run/$name"; + } + + return 'run/predict'; + } } diff --git a/tests/ExampleTest.php b/tests/ExampleTest.php index 4bd886b..c46b476 100644 --- a/tests/ExampleTest.php +++ b/tests/ExampleTest.php @@ -39,3 +39,16 @@ expect($client)->toBeInstanceOf(Client::class); }); + +it('can test fnindexsudgugdhs', function () { + $client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/'); + + $client->predict([], fnIndex: 6, raw: true); + $client->predict(['hi'], fnIndex: 2, raw: true); + $client->predict([null, null], fnIndex: 3, raw: true); + $response = $client->predict([null, null, "", 0.9, 256, 0.6, 1.2], fnIndex: 4); + + $value = $response->getOutput(); + + expect($value)->toBeArray(); +});