Unverified commit 5e97eb18, authored by takatost, committed by GitHub

fix: azure openai stream response usage missing (#1998)

parent c9e4147b
...@@ -257,6 +257,9 @@ class AppRunner: ...@@ -257,6 +257,9 @@ class AppRunner:
if not usage and result.delta.usage: if not usage and result.delta.usage:
usage = result.delta.usage usage = result.delta.usage
if not usage:
usage = LLMUsage.empty_usage()
llm_result = LLMResult( llm_result = LLMResult(
model=model, model=model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
......
...@@ -322,8 +322,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -322,8 +322,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
response: Stream[ChatCompletionChunk], response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> Generator: tools: Optional[list[PromptMessageTool]] = None) -> Generator:
index = 0
full_assistant_content = '' full_assistant_content = ''
real_model = model
system_fingerprint = None
completion = ''
for chunk in response: for chunk in response:
if len(chunk.choices) == 0: if len(chunk.choices) == 0:
continue continue
...@@ -349,13 +352,27 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -349,13 +352,27 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
full_assistant_content += delta.delta.content if delta.delta.content else '' full_assistant_content += delta.delta.content if delta.delta.content else ''
if delta.finish_reason is not None: real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content if delta.delta.content else ''
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
)
index += 0
# calculate num tokens # calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage( full_assistant_prompt_message = AssistantPromptMessage(
content=full_assistant_content, content=completion
tool_calls=tool_calls
) )
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
...@@ -363,26 +380,16 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ...@@ -363,26 +380,16 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk( yield LLMResultChunk(
model=chunk.model, model=real_model,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint, system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta( delta=LLMResultChunkDelta(
index=delta.index, index=index,
message=assistant_prompt_message, message=AssistantPromptMessage(content=''),
finish_reason=delta.finish_reason, finish_reason='stop',
usage=usage usage=usage
) )
) )
else:
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
)
)
@staticmethod @staticmethod
def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
......
...@@ -190,7 +190,6 @@ def test_invoke_stream_chat_model(setup_openai_mock): ...@@ -190,7 +190,6 @@ def test_invoke_stream_chat_model(setup_openai_mock):
assert isinstance(chunk, LLMResultChunk) assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage) assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
if chunk.delta.finish_reason is not None: if chunk.delta.finish_reason is not None:
assert chunk.delta.usage is not None assert chunk.delta.usage is not None
assert chunk.delta.usage.completion_tokens > 0 assert chunk.delta.usage.completion_tokens > 0
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment