diff --git a/src/groq/resources/chat/completions.py b/src/groq/resources/chat/completions.py index 2e199f2..d171b64 100644 --- a/src/groq/resources/chat/completions.py +++ b/src/groq/resources/chat/completions.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Dict, List, Union, Iterable, Optional +from typing import Dict, List, Union, Iterable, Optional, overload +from typing_extensions import Literal import httpx @@ -19,10 +20,12 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._streaming import Stream, AsyncStream from ...types.chat import completion_create_params from ..._base_client import ( make_request_options, ) +from ...lib.chat_completion_chunk import ChatCompletionChunk from ...types.chat.chat_completion import ChatCompletion __all__ = ["Completions", "AsyncCompletions"] @@ -37,6 +40,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse: def with_streaming_response(self) -> CompletionsWithStreamingResponse: return CompletionsWithStreamingResponse(self) + @overload def create( self, *, @@ -53,7 +57,7 @@ def create( response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, seed: Optional[int] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[bool] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, @@ -67,6 +71,104 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: + ... + + @overload + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN, + functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, + seed: Optional[int] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Literal[True], + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, + tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: Optional[str] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[ChatCompletionChunk]: + ... + + @overload + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN, + functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, + seed: Optional[int] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: bool, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, + tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: Optional[str] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: + ... + + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN, + functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, + seed: Optional[int] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, + tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: Optional[str] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: """ Creates a model response for the given chat conversation. @@ -203,6 +305,8 @@ def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], ) @@ -215,6 +319,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse: def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse: return AsyncCompletionsWithStreamingResponse(self) + @overload async def create( self, *, @@ -231,7 +336,7 @@ async def create( response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, seed: Optional[int] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[bool] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, @@ -245,6 +350,104 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: + ... + + @overload + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN, + functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, + seed: Optional[int] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Literal[True], + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, + tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: Optional[str] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[ChatCompletionChunk]: + ... + + @overload + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN, + functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, + seed: Optional[int] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: bool, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, + tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: Optional[str] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: + ... + + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + function_call: Optional[completion_create_params.FunctionCall] | NotGiven = NOT_GIVEN, + functions: Optional[Iterable[completion_create_params.Function]] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: Optional[completion_create_params.ResponseFormat] | NotGiven = NOT_GIVEN, + seed: Optional[int] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tool_choice: Optional[completion_create_params.ToolChoice] | NotGiven = NOT_GIVEN, + tools: Optional[Iterable[completion_create_params.Tool]] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: Optional[str] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: """ Creates a model response for the given chat conversation. @@ -381,6 +584,8 @@ async def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, + stream=stream or False, + stream_cls=AsyncStream[ChatCompletionChunk], ) diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py index 96e5cd9..4f0945d 100644 --- a/tests/api_resources/chat/test_completions.py +++ b/tests/api_resources/chat/test_completions.py @@ -89,7 +89,7 @@ def test_method_create_with_all_params(self, client: Groq) -> None: response_format={"type": "string"}, seed=0, stop="\n", - stream=True, + stream=False, temperature=0, tool_choice="none", tools=[ @@ -252,7 +252,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGroq) -> N response_format={"type": "string"}, seed=0, stop="\n", - stream=True, + stream=False, temperature=0, tool_choice="none", tools=[