Unverified Commit 3e63abd3 authored by Yeuoly, committed by GitHub

Feat/json mode (#2563)

parent 0620fa30
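This commit adds a "JSON mode" (code-block mode) to the model runtime: a shared `response_format` default parameter rule, `response_format` entries in the provider YAML files, and per-provider `_code_block_mode_wrapper` implementations that steer models without a native JSON mode via prompt engineering. The recurring trick: rewrite the system prompt with a block-mode template, open a fenced block for the model (a prefilled assistant turn or a trailing ` ```JSON `), and cut generation at the closing fence via stop sequences. A minimal standalone sketch of that idea (the names here are illustrative, not the runtime's actual API):

```python
# Sketch of the prompt-side JSON mode shared by the providers below.
BLOCK_MODE_PROMPT = (
    "You should always follow the instructions and output a valid {block} object.\n"
    "<instructions>\n{instructions}\n</instructions>\n"
)

def build_json_mode_request(instructions: str, block: str = "JSON") -> dict:
    """Assemble the pieces a provider call would need for block mode."""
    return {
        "system": BLOCK_MODE_PROMPT.format(block=block, instructions=instructions),
        "assistant_prefix": f"```{block}\n",   # the model continues inside the fence
        "stop": ["```\n", "\n```"],            # halt at the closing fence
    }

print(build_json_mode_request('Answer as {"answer": "$your_answer"}'))
```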
@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
         'min': 1,
         'max': 2048,
         'precision': 0,
-    }
+    },
+    DefaultParameterName.RESPONSE_FORMAT: {
+        'label': {
+            'en_US': 'Response Format',
+            'zh_Hans': '回复格式',
+        },
+        'type': 'string',
+        'help': {
+            'en_US': 'Set a response format to ensure the output from the LLM is a valid code block where possible, such as JSON or XML.',
+            'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
+        },
+        'required': False,
+        'options': ['JSON', 'XML'],
+    }
 }
\ No newline at end of file
@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
     PRESENCE_PENALTY = "presence_penalty"
     FREQUENCY_PENALTY = "frequency_penalty"
     MAX_TOKENS = "max_tokens"
+    RESPONSE_FORMAT = "response_format"

     @classmethod
     def value_of(cls, value: Any) -> 'DefaultParameterName':
......
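`value_of` (context above) is what maps the YAML string `response_format` back to the new enum member. A self-contained sketch of that lookup pattern, assuming the usual value-scan implementation:

```python
from enum import Enum
from typing import Any

class DefaultParameterName(Enum):
    MAX_TOKENS = "max_tokens"
    RESPONSE_FORMAT = "response_format"

    @classmethod
    def value_of(cls, value: Any) -> "DefaultParameterName":
        # Scan members for a matching value; raise otherwise, as the runtime does.
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"invalid parameter name {value}")

assert DefaultParameterName.value_of("response_format") is DefaultParameterName.RESPONSE_FORMAT
```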
@@ -262,23 +262,23 @@ class AIModel(ABC):
             try:
                 default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                 default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
-                if not parameter_rule.max:
+                if not parameter_rule.max and 'max' in default_parameter_rule:
                     parameter_rule.max = default_parameter_rule['max']
-                if not parameter_rule.min:
+                if not parameter_rule.min and 'min' in default_parameter_rule:
                     parameter_rule.min = default_parameter_rule['min']
-                if not parameter_rule.precision:
+                if not parameter_rule.default and 'default' in default_parameter_rule:
                     parameter_rule.default = default_parameter_rule['default']
-                if not parameter_rule.precision:
+                if not parameter_rule.precision and 'precision' in default_parameter_rule:
                     parameter_rule.precision = default_parameter_rule['precision']
-                if not parameter_rule.required:
+                if not parameter_rule.required and 'required' in default_parameter_rule:
                     parameter_rule.required = default_parameter_rule['required']
-                if not parameter_rule.help:
+                if not parameter_rule.help and 'help' in default_parameter_rule:
                     parameter_rule.help = I18nObject(
                         en_US=default_parameter_rule['help']['en_US'],
                     )
-                if not parameter_rule.help.en_US:
+                if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
                     parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
-                if not parameter_rule.help.zh_Hans:
+                if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
                     parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
             except ValueError:
                 pass
......
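The added `'key' in default_parameter_rule` guards are what make the new template entry safe: RESPONSE_FORMAT defines no `min`/`max`/`default`/`precision`, so unconditional indexing would raise `KeyError`. A small sketch of the merge behavior (the `Rule` dataclass is a hypothetical stand-in for `ParameterRule`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Rule:  # hypothetical stand-in for ParameterRule
    max: Optional[int] = None
    required: Optional[bool] = None

# RESPONSE_FORMAT-style template: defines 'required' but no 'max'.
template = {"required": False, "options": ["JSON", "XML"]}

rule = Rule()
# Copy a default only when the template actually defines the key.
if not rule.max and "max" in template:
    rule.max = template["max"]
if not rule.required and "required" in template:
    rule.required = template["required"]

print(rule)  # Rule(max=None, required=False) — no KeyError for the absent 'max'
```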
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'
......
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'
......
@@ -26,6 +26,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '1.63'
   output: '5.51'
......
@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
 from anthropic.types import Completion, completion_create_params
 from httpx import Timeout

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions. Use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+

 class AnthropicLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        if 'response_format' in model_parameters and model_parameters['response_format']:
+            stop = stop or []
+            self._transform_json_prompts(
+                model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
+            )
+            model_parameters.pop('response_format')
+
+        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _transform_json_prompts(self, model: str, credentials: dict,
+                                prompt_messages: list[PromptMessage], model_parameters: dict,
+                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts
+        """
+        if "```\n" not in stop:
+            stop.append("```\n")
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=ANTHROPIC_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", prompt_messages[0].content)
+                .replace("{{block}}", response_format)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=ANTHROPIC_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                .replace("{{block}}", response_format)
+            ))
+
+        prompt_messages.append(AssistantPromptMessage(
+            content=f"```{response_format}\n"
+        ))
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
......
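For Anthropic the wrapper leans on assistant prefill: after `_transform_json_prompts`, a JSON-mode request looks roughly like the sketch below (illustrative contents), and the `` ```\n `` stop sequence ends generation at the closing fence.

```python
# Illustrative shape of the transformed Anthropic request for response_format='JSON'.
system = (
    "You should always follow the instructions and output a valid JSON object.\n"
    '<instructions>\nExtract the name as {"name": "..."}\n</instructions>\n'
)
messages = [
    ("system", system),
    ("user", "My name is Ada."),
    ("assistant", "```JSON\n"),  # prefilled: the model continues inside the fence
]
stop = ["```\n"]  # generation halts once the model closes the block
print(messages[-1])
```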
@@ -27,6 +27,8 @@ parameter_rules:
     default: 2048
     min: 1
     max: 2048
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.00'
   output: '0.00'
......
@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

 logger = logging.getLogger(__name__)

+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions. Use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+

 class GoogleLargeLanguageModel(LargeLanguageModel):

     def _invoke(self, model: str, credentials: dict,
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
......
@@ -24,6 +24,18 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: Specify the format that the model must output.
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.0005'
   output: '0.0015'
......
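Unlike the prompt-engineered `JSON`/`XML` template, gpt-3.5-turbo-0125 gets explicit `text`/`json_object` options because these map to OpenAI's native JSON mode. With the official SDK that corresponds roughly to:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
resp = client.chat.completions.create(
    model="gpt-3.5-turbo-0125",
    response_format={"type": "json_object"},  # native JSON mode
    messages=[
        {"role": "system", "content": "Reply in JSON."},
        {"role": "user", "content": "Name one prime number."},
    ],
)
print(resp.choices[0].message.content)
```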
@@ -24,6 +24,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.0015'
   output: '0.002'
......
@@ -24,6 +24,18 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: Specify the format that the model must output.
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.001'
   output: '0.002'
......
@@ -24,6 +24,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 16385
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.003'
   output: '0.004'
......
@@ -24,6 +24,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 16385
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.003'
   output: '0.004'
......
@@ -21,6 +21,8 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.0015'
   output: '0.002'
......
@@ -24,6 +24,18 @@ parameter_rules:
     default: 512
     min: 1
     max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: Specify the format that the model must output.
+    required: false
+    options:
+      - text
+      - json_object
 pricing:
   input: '0.001'
   output: '0.002'
......
@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
 from openai.types.chat.chat_completion_message import FunctionCall

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI

 logger = logging.getLogger(__name__)

+OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions. Use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+

 class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
     """
@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             user=user
         )

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        # handle fine tune remote models
+        base_model = model
+        if model.startswith('ft:'):
+            base_model = model.split(':')[1]
+
+        # get model mode
+        model_mode = self.get_model_mode(base_model, credentials)
+
+        # transform response format
+        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
+            stop = stop or []
+            if model_mode == LLMMode.CHAT:
+                # chat model
+                self._transform_chat_json_prompts(
+                    model=base_model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                    response_format=model_parameters['response_format']
+                )
+            else:
+                self._transform_completion_json_prompts(
+                    model=base_model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                    response_format=model_parameters['response_format']
+                )
+            model_parameters.pop('response_format')
+
+        return self._invoke(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+    def _transform_chat_json_prompts(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts
+        """
+        if "```\n" not in stop:
+            stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=OPENAI_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", prompt_messages[0].content)
+                .replace("{{block}}", response_format)
+            )
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=OPENAI_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                .replace("{{block}}", response_format)
+            ))
+            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
+
+    def _transform_completion_json_prompts(self, model: str, credentials: dict,
+                                           prompt_messages: list[PromptMessage], model_parameters: dict,
+                                           tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                           stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts
+        """
+        if "```\n" not in stop:
+            stop.append("```\n")
+        if "\n```" not in stop:
+            stop.append("\n```")
+
+        # override the last user message
+        user_message = None
+        for i in range(len(prompt_messages) - 1, -1, -1):
+            if isinstance(prompt_messages[i], UserPromptMessage):
+                user_message = prompt_messages[i]
+                break
+
+        if user_message:
+            if prompt_messages[i].content[-11:] == 'Assistant: ':
+                # now we are in the chat app, remove the last assistant message
+                prompt_messages[i].content = prompt_messages[i].content[:-11]
+                prompt_messages[i] = UserPromptMessage(
+                    content=OPENAI_BLOCK_MODE_PROMPT
+                    .replace("{{instructions}}", user_message.content)
+                    .replace("{{block}}", response_format)
+                )
+                prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
+            else:
+                prompt_messages[i] = UserPromptMessage(
+                    content=OPENAI_BLOCK_MODE_PROMPT
+                    .replace("{{instructions}}", user_message.content)
+                    .replace("{{block}}", response_format)
+                )
+                prompt_messages[i].content += f"\n```{response_format}\n"
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
......
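One detail worth noting in the wrapper above: fine-tuned model IDs look like `ft:<base>:<org>::<job-id>`, so the base model is recovered from the second `:`-separated field before checking whether the model runs in chat or completion mode. For example (hypothetical fine-tune ID):

```python
model = "ft:gpt-3.5-turbo-0613:acme::abc123"  # hypothetical fine-tuned model ID
base_model = model.split(":")[1] if model.startswith("ft:") else model
print(base_model)  # gpt-3.5-turbo-0613
```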
@@ -13,6 +13,7 @@ from dashscope.common.error import (
 )
 from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

+    def _code_block_mode_wrapper(self, model: str, credentials: dict,
+                                 prompt_messages: list[PromptMessage], model_parameters: dict,
+                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                 stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
+            -> LLMResult | Generator:
+        """
+        Wrapper for code block mode
+        """
+        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions. Use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+
+        code_block = model_parameters.get("response_format", "")
+        if not code_block:
+            return self._invoke(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+
+        model_parameters.pop("response_format")
+        stop = stop or []
+        stop.extend(["\n```", "```\n"])
+        block_prompts = block_prompts.replace("{{block}}", code_block)
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=block_prompts
+                .replace("{{instructions}}", prompt_messages[0].content)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=block_prompts
+                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+            ))
+
+        mode = self.get_model_mode(model, credentials)
+        if mode == LLMMode.CHAT:
+            if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+                # add ```JSON\n to the last message
+                prompt_messages[-1].content += f"\n```{code_block}\n"
+            else:
+                # append a user message
+                prompt_messages.append(UserPromptMessage(
+                    content=f"```{code_block}\n"
+                ))
+        else:
+            prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
+
+        response = self._invoke(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+        if isinstance(response, Generator):
+            return self._code_block_mode_stream_processor_with_backtick(
+                model=model,
+                prompt_messages=prompt_messages,
+                input_generator=response
+            )
+
+        return response
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         """
         extra_model_kwargs = {}
         if stop:
-            extra_model_kwargs['stop_sequences'] = stop
+            extra_model_kwargs['stop'] = stop

         # transform credentials to kwargs for model instance
         credentials_kwargs = self._to_credential_kwargs(credentials)
@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         params = {
             'model': model,
             **model_parameters,
-            **credentials_kwargs
+            **credentials_kwargs,
+            **extra_model_kwargs,
         }

         mode = self.get_model_mode(model, credentials)
......
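The Tongyi changes fix two pieces of plumbing: the stop list is stored under `stop` (the key this diff has the dashscope call read) instead of `stop_sequences`, and `extra_model_kwargs` is now unpacked into `params`, so it reaches the SDK at all. Since later `**` unpacks win on key collisions, the merge order matters:

```python
model_parameters = {"temperature": 0.7}
extra_model_kwargs = {"stop": ["\n```", "```\n"]}

params = {
    "model": "qwen-max",
    **model_parameters,
    **extra_model_kwargs,  # merged last, so its keys take precedence
}
print(params["stop"])  # ['\n```', '```\n'] — previously this never reached the SDK
```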
@@ -57,3 +57,5 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Controls the repetitiveness of generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
     required: false
+  - name: response_format
+    use_template: response_format
@@ -57,3 +57,5 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Controls the repetitiveness of generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
     required: false
+  - name: response_format
+    use_template: response_format
@@ -57,3 +57,5 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Controls the repetitiveness of generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
     required: false
+  - name: response_format
+    use_template: response_format
@@ -56,6 +56,8 @@ parameter_rules:
     help:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Controls the repetitiveness of generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.02'
   output: '0.02'
......
@@ -57,6 +57,8 @@ parameter_rules:
       zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
       en_US: Controls the repetitiveness of generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
     required: false
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '0.008'
   output: '0.008'
......
@@ -25,6 +25,8 @@ parameter_rules:
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    use_template: response_format
   - name: disable_search
     label:
       zh_Hans: 禁用搜索
......
@@ -25,6 +25,8 @@ parameter_rules:
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    use_template: response_format
   - name: disable_search
     label:
       zh_Hans: 禁用搜索
......
@@ -25,3 +25,5 @@ parameter_rules:
     use_template: presence_penalty
   - name: frequency_penalty
     use_template: frequency_penalty
+  - name: response_format
+    use_template: response_format
@@ -34,3 +34,5 @@ parameter_rules:
       zh_Hans: 禁用模型自行进行外部搜索。
       en_US: Disable the model from performing external searches on its own.
     required: false
+  - name: response_format
+    use_template: response_format
 from collections.abc import Generator
-from typing import cast
+from typing import Optional, Union, cast

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
     RateLimitReachedError,
 )

+ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions. Use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+
+You should also complete the text that starts with ``` but not output ``` directly.
+"""
+

-class ErnieBotLarguageModel(LargeLanguageModel):
+class ErnieBotLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel):
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                               model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper for invoking large language model
+        """
+        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
+            response_format = model_parameters['response_format']
+            stop = stop or []
+            self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
+            model_parameters.pop('response_format')
+            if stream:
+                return self._code_block_mode_stream_processor(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                 model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+                )
+
+        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _transform_json_prompts(self, model: str, credentials: dict,
+                                prompt_messages: list[PromptMessage], model_parameters: dict,
+                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+            -> None:
+        """
+        Transform json prompts to model prompts
+        """
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=ERNIE_BOT_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", prompt_messages[0].content)
+                .replace("{{block}}", response_format)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=ERNIE_BOT_BLOCK_MODE_PROMPT
+                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+                .replace("{{block}}", response_format)
+            ))
+
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+            # add ```JSON\n to the last message
+            prompt_messages[-1].content += "\n```JSON\n{\n"
+        else:
+            # append a user message
+            prompt_messages.append(UserPromptMessage(
+                content="```JSON\n{\n"
+            ))
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
         # tools is not supported yet
......
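Ernie's variant prefills `` \n```JSON\n{\n `` on the last user message, so the model continues an object that is already open; whatever consumes the completion (here the `_code_block_mode_stream_processor`) has to account for the opening brace living in the prompt. A rough sketch of the reassembly, with a hypothetical model continuation:

```python
import json

# The prompt ends with an already-opened object: "```JSON\n{\n"
completion = '"answer": "42"\n}'  # hypothetical model continuation

# Re-attach the brace that was sent as part of the prompt before parsing.
parsed = json.loads("{\n" + completion)
print(parsed)  # {'answer': '42'}
```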
@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
 from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
 from core.model_runtime.utils import helper

+GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
+The structure of the JSON object can be found in the instructions. Use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+And you should always end the block with a "```" to indicate the end of the JSON object.
+
+<instructions>
+{{instructions}}
+</instructions>
+
+```JSON"""
+

 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)

         # invoke model
+        # stop = stop or []
+        # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
         return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)

+    # def _transform_json_prompts(self, model: str, credentials: dict,
+    #                             prompt_messages: list[PromptMessage], model_parameters: dict,
+    #                             tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+    #                             stream: bool = True, user: str | None = None) \
+    #         -> None:
+    #     """
+    #     Transform json prompts to model prompts
+    #     """
+    #     if "}\n\n" not in stop:
+    #         stop.append("}\n\n")
+
+    #     # check if there is a system message
+    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+    #         # override the system message
+    #         prompt_messages[0] = SystemPromptMessage(
+    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
+    #         )
+    #     else:
+    #         # insert the system message
+    #         prompt_messages.insert(0, SystemPromptMessage(
+    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
+    #         ))
+
+    #     # check if the last message is a user message
+    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+    #         # add ```JSON\n to the last message
+    #         prompt_messages[-1].content += "\n```JSON\n"
+    #     else:
+    #         # append a user message
+    #         prompt_messages.append(UserPromptMessage(
+    #             content="```JSON\n"
+    #         ))
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         """
         extra_model_kwargs = {}
         if stop:
-            extra_model_kwargs['stop_sequences'] = stop
+            extra_model_kwargs['stop'] = stop

         client = ZhipuAI(
             api_key=credentials_kwargs['api_key']
@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         ]

         if stream:
-            response = client.chat.completions.create(stream=stream, **params)
+            response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
             return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)

-        response = client.chat.completions.create(**params)
+        response = client.chat.completions.create(**params, **extra_model_kwargs)
         return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)

     def _handle_generate_response(self, model: str,
......
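For ZhipuAI the JSON-mode prompt transform stays commented out; what lands is the stop-sequence plumbing (the `stop` key, now forwarded through `client.chat.completions.create`). Since `GLM_JSON_MODE_PROMPT` itself opens the fence with a trailing `` ```JSON ``, a consumer would strip the closing fence from the raw completion before parsing; roughly:

```python
import json

raw = '{"answer": "42"}\n```'  # hypothetical completion ending at the fence
body = raw.split("```", 1)[0].strip()  # drop the closing fence
print(json.loads(body))  # {'answer': '42'}
```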
@@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel
+from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel


 def test_predefined_models():
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()
     model_schemas = model.predefined_models()

     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)


 def test_validate_credentials_for_chat_model():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
@@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model():

 def test_invoke_model_ernie_bot():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',
@@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot():

 def test_invoke_model_ernie_bot_turbo():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot-turbo',
@@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo():

 def test_invoke_model_ernie_8k():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot-8k',
@@ -123,7 +123,7 @@ def test_invoke_model_ernie_8k():

 def test_invoke_model_ernie_bot_4():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot-4',
@@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4():

 def test_invoke_stream_model():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',
@@ -182,7 +182,7 @@ def test_invoke_stream_model():

 def test_invoke_model_with_system():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',
@@ -212,7 +212,7 @@ def test_invoke_model_with_system():

 def test_invoke_with_search():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.invoke(
         model='ernie-bot',
@@ -250,7 +250,7 @@ def test_invoke_with_search():

 def test_get_num_tokens():
     sleep(3)
-    model = ErnieBotLarguageModel()
+    model = ErnieBotLargeLanguageModel()

     response = model.get_num_tokens(
         model='ernie-bot',