Unverified Commit ede69b46 authored by takatost's avatar takatost Committed by GitHub

fix: gemini block error (#1877)

Co-authored-by: 's avatarchenhe <guchenhe@gmail.com>
parent 61aaeff4
...@@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel): ...@@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None system_fingerprint = None
real_model = model real_model = model
for chunk in result: try:
try: for chunk in result:
yield chunk yield chunk
self._trigger_new_chunk_callbacks( self._trigger_new_chunk_callbacks(
...@@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel): ...@@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel):
if chunk.system_fingerprint: if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint system_fingerprint = chunk.system_fingerprint
except Exception as e: except Exception as e:
raise self._transform_invoke_error(e) raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks( self._trigger_after_invoke_callbacks(
model=model, model=model,
......
...@@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List ...@@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List
import google.generativeai as genai import google.generativeai as genai
import google.api_core.exceptions as exceptions import google.api_core.exceptions as exceptions
import google.generativeai.client as client import google.generativeai.client as client
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from google.generativeai.types import GenerateContentResponse, ContentType from google.generativeai.types import GenerateContentResponse, ContentType
from google.generativeai.types.content_types import to_part from google.generativeai.types.content_types import to_part
...@@ -124,7 +125,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ...@@ -124,7 +125,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
last_msg = prompt_messages[-1] last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg) content = self._format_message_to_glm_content(last_msg)
history.append(content) history.append(content)
else: else:
for msg in prompt_messages: # makes message roles strictly alternating for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg) content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]: if history and history[-1]["role"] == content["role"]:
...@@ -139,13 +140,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ...@@ -139,13 +140,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
new_custom_client = new_client_manager.make_client("generative") new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client google_model._client = new_custom_client
safety_settings={
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
response = google_model.generate_content( response = google_model.generate_content(
contents=history, contents=history,
generation_config=genai.types.GenerationConfig( generation_config=genai.types.GenerationConfig(
**config_kwargs **config_kwargs
), ),
stream=stream stream=stream,
safety_settings=safety_settings
) )
if stream: if stream:
...@@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ...@@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content=response.text content=response.text
) )
# calculate num tokens # calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
...@@ -202,11 +210,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ...@@ -202,11 +210,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for chunk in response: for chunk in response:
content = chunk.text content = chunk.text
index += 1 index += 1
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=content if content else '', content=content if content else '',
) )
if not response._done: if not response._done:
# transform assistant message to prompt message # transform assistant message to prompt message
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment