openai: raw response headers #24150

Merged Jul 16, 2024 (14 commits)
6 changes: 4 additions & 2 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
@@ -928,7 +928,9 @@ def _get_ls_params(
return params

def _create_chat_result(
self, response: Union[dict, openai.BaseModel]
self,
response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None,
) -> ChatResult:
if not isinstance(response, dict):
response = response.model_dump()
@@ -938,7 +940,7 @@ def _create_chat_result(
"Azure has not provided the response due to a content filter "
"being triggered"
)
chat_result = super()._create_chat_result(response)
chat_result = super()._create_chat_result(response, generation_info)

if "model" in response:
model = response["model"]
58 changes: 48 additions & 10 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -367,6 +367,8 @@ class BaseChatOpenAI(BaseChatModel):
extra_body: Optional[Mapping[str, Any]] = None
"""Optional additional JSON properties to include in the request parameters when
making requests to OpenAI compatible APIs, such as vLLM."""
include_response_headers: bool = False
"""Whether to include response headers in the output message response_metadata."""

class Config:
"""Configuration for this pydantic object."""
@@ -510,7 +512,9 @@ def _stream(
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
with self.client.create(**payload) as response:
with self.client.with_raw_response.create(**payload) as raw_response:

Member Author:
tradeoff: this code is more maintainable, but it won't tolerate a custom client/async_client that doesn't implement with_raw_response (i.e. non-openai clients).

Could make it tolerate this by only raw-responsing if include_response_headers is set to true.

I'm generally in favor of maintainability here. Happy to discuss.

Member Author:
switching back to not break things
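
A hypothetical sketch of the fallback mentioned above — only going through with_raw_response when include_response_headers is set, so a custom client without that attribute keeps working (illustrative, written against the non-streaming path; not what this diff does):

    if self.include_response_headers:
        raw_response = self.client.with_raw_response.create(**payload)
        response = raw_response.parse()
        generation_info = {"headers": dict(raw_response.headers)}
    else:
        # Custom clients that don't implement with_raw_response still work here.
        response = self.client.create(**payload)
        generation_info = None
    return self._create_chat_result(response, generation_info)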

response = raw_response.parse()
is_first_chunk = True
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@@ -536,7 +540,11 @@ def _stream(
message_chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
generation_info = (
{"headers": dict(raw_response.headers)}
if self.include_response_headers and is_first_chunk
else {}
)
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
@@ -555,6 +563,7 @@ def _stream(
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
is_first_chunk = False
yield generation_chunk

def _generate(
@@ -570,8 +579,16 @@ def _generate(
)
return generate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = self.client.create(**payload)
return self._create_chat_result(response)
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
return self._create_chat_result(
response,
(
{"headers": dict(raw_response.headers)}
if self.include_response_headers
else None
),
)

def _get_request_payload(
self,
@@ -590,7 +607,9 @@ def _get_request_payload(
}

def _create_chat_result(
self, response: Union[dict, openai.BaseModel]
self,
response: Union[dict, openai.BaseModel],
generation_info: Optional[Dict] = None,
) -> ChatResult:
generations = []
if not isinstance(response, dict):
@@ -612,7 +631,9 @@ def _create_chat_result(
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
generation_info = dict(finish_reason=res.get("finish_reason"))
generation_info = dict(
finish_reason=res.get("finish_reason"), **(generation_info or {})
)
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(message=message, generation_info=generation_info)
@@ -634,8 +655,10 @@ async def _astream(
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
response = await self.async_client.create(**payload)
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
async with response:
is_first_chunk = True
async for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@@ -664,7 +687,11 @@ async def _astream(
choice["delta"],
default_chunk_class,
)
generation_info = {}
generation_info = (
{"headers": dict(raw_response.headers)}
if self.include_response_headers and is_first_chunk
else {}
)
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
@@ -685,6 +712,7 @@ async def _astream(
chunk=generation_chunk,
logprobs=logprobs,
)
is_first_chunk = False
yield generation_chunk

async def _agenerate(
@@ -700,8 +728,18 @@ async def _agenerate(
)
return await agenerate_from_stream(stream_iter)
payload = self._get_request_payload(messages, stop=stop, **kwargs)
response = await self.async_client.create(**payload)
return await run_in_executor(None, self._create_chat_result, response)
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
return await run_in_executor(
None,
self._create_chat_result,
response,
(
{"headers": dict(raw_response.headers)}
if self.include_response_headers
else None
),
)

@property
def _identifying_params(self) -> Dict[str, Any]:
@@ -319,6 +319,9 @@ def test_openai_invoke() -> None:
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)

# assert no response headers if include_response_headers is not set
assert "headers" not in result.response_metadata


def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
@@ -671,3 +674,13 @@ def test_openai_proxy() -> None:
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
assert proxy.port == 8080


def test_openai_response_headers_invoke() -> None:
"""Test ChatOpenAI response headers."""
chat_openai = ChatOpenAI(include_response_headers=True)
result = chat_openai.invoke("I'm Pickle Rick")
headers = result.response_metadata["headers"]
assert headers
assert isinstance(headers, dict)
assert "content-type" in headers
67 changes: 41 additions & 26 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -189,38 +189,56 @@ def mock_completion() -> dict:
}


def test_openai_invoke(mock_completion: dict) -> None:
llm = ChatOpenAI()
mock_client = MagicMock()
completed = False
@pytest.fixture
def mock_client(mock_completion: dict) -> MagicMock:
rtn = MagicMock()

mock_create = MagicMock()

mock_resp = MagicMock()
mock_resp.headers = {"content-type": "application/json"}
mock_resp.parse.return_value = mock_completion
mock_create.return_value = mock_resp

rtn.with_raw_response.create = mock_create
return rtn


@pytest.fixture
def mock_async_client(mock_completion: dict) -> AsyncMock:
rtn = AsyncMock()

def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
mock_create = AsyncMock()
mock_resp = MagicMock()
mock_resp.parse.return_value = mock_completion
mock_create.return_value = mock_resp

rtn.with_raw_response.create = mock_create
return rtn


def test_openai_invoke(mock_client: MagicMock) -> None:
llm = ChatOpenAI()

mock_client.create = mock_create
with patch.object(llm, "client", mock_client):
res = llm.invoke("bar")
assert res.content == "Bar Baz"
assert completed

# headers are not in response_metadata if include_response_headers not set
assert "headers" not in res.response_metadata
assert mock_client.with_raw_response.create.called

async def test_openai_ainvoke(mock_completion: dict) -> None:
llm = ChatOpenAI()
mock_client = AsyncMock()
completed = False

async def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed
completed = True
return mock_completion
async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
llm = ChatOpenAI()

mock_client.create = mock_create
with patch.object(llm, "async_client", mock_client):
with patch.object(llm, "async_client", mock_async_client):
res = await llm.ainvoke("bar")
assert res.content == "Bar Baz"
assert completed

# headers are not in response_metadata if include_response_headers not set
assert "headers" not in res.response_metadata
assert mock_async_client.with_raw_response.create.called


@pytest.mark.parametrize(
Expand All @@ -239,16 +257,13 @@ def test__get_encoding_model(model: str) -> None:
return


def test_openai_invoke_name(mock_completion: dict) -> None:
def test_openai_invoke_name(mock_client: MagicMock) -> None:
llm = ChatOpenAI()

mock_client = MagicMock()
mock_client.create.return_value = mock_completion

with patch.object(llm, "client", mock_client):
messages = [HumanMessage(content="Foo", name="Katie")]
res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args
call_args, call_kwargs = mock_client.with_raw_response.create.call_args
assert len(call_args) == 0 # no positional args
call_messages = call_kwargs["messages"]
assert len(call_messages) == 1