Unverified Commit 77f9e8ce authored by Chenhe Gu's avatar Chenhe Gu Committed by GitHub

add example api url endpoint in placeholder (#1887)

Co-authored-by: 's avatartakatost <takatost@gmail.com>
parent 5ca4c4a4
import logging import logging
from decimal import Decimal from decimal import Decimal
from urllib.parse import urljoin
import requests import requests
import json import json
...@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast ...@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.utils import helper from core.model_runtime.utils import helper
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \ from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage AssistantPromptMessage, PromptMessageContent, \
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \ PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
DefaultParameterName, \
ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
...@@ -70,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -70,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
:return: :return:
""" """
return self._num_tokens_from_messages(model, prompt_messages, tools) return self._num_tokens_from_messages(model, prompt_messages, tools)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """
Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard. Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
...@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials['endpoint_url'] endpoint_url = credentials['endpoint_url']
if not endpoint_url.endswith('/'):
endpoint_url += '/'
# prepare the payload for a simple ping to the model # prepare the payload for a simple ping to the model
data = { data = {
...@@ -105,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -105,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"content": "ping" "content": "ping"
}, },
] ]
endpoint_url = urljoin(endpoint_url, 'chat/completions')
elif completion_type is LLMMode.COMPLETION: elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping' data['prompt'] = 'ping'
endpoint_url = urljoin(endpoint_url, 'completions')
else: else:
raise ValueError("Unsupported completion type for model configuration.") raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials # send a post request to validate the credentials
response = requests.post( response = requests.post(
endpoint_url, endpoint_url,
...@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
) )
if response.status_code != 200: if response.status_code != 200:
raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}') raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if (completion_type is LLMMode.CHAT
and ('object' not in json_result or json_result['object'] != 'chat.completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
elif (completion_type is LLMMode.COMPLETION
and ('object' not in json_result or json_result['object'] != 'text_completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'text_completion\'')
except CredentialsValidateFailedError:
raise
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
...@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model_type=ModelType.LLM, model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'), ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MODE: 'chat' ModelPropertyKey.MODE: credentials.get('mode'),
}, },
parameter_rules=[ parameter_rules=[
ParameterRule( ParameterRule(
...@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return entity return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \ tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
user: Optional[str] = None) -> Union[LLMResult, Generator]: stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
""" """
Invoke llm completion model Invoke llm completion model
...@@ -223,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -223,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"] endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith('/'):
endpoint_url += '/'
data = { data = {
"model": model, "model": model,
"stream": stream, "stream": stream,
...@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
completion_type = LLMMode.value_of(credentials['mode']) completion_type = LLMMode.value_of(credentials['mode'])
if completion_type is LLMMode.CHAT: if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'chat/completions')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
elif completion_type == LLMMode.COMPLETION: elif completion_type == LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, 'completions')
data['prompt'] = prompt_messages[0].content data['prompt'] = prompt_messages[0].content
else: else:
raise ValueError("Unsupported completion type for model configuration.") raise ValueError("Unsupported completion type for model configuration.")
...@@ -245,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -245,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
data["tool_choice"] = "auto" data["tool_choice"] = "auto"
for tool in tools: for tool in tools:
formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool))) formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools data["tools"] = formatted_tools
if stop: if stop:
...@@ -254,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -254,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if user: if user:
data["user"] = user data["user"] = user
response = requests.post( response = requests.post(
endpoint_url, endpoint_url,
headers=headers, headers=headers,
...@@ -275,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -275,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator: prompt_messages: list[PromptMessage]) -> Generator:
""" """
Handle llm stream response Handle llm stream response
...@@ -313,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -313,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if chunk: if chunk:
decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip() decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
chunk_json = None
try: try:
chunk_json = json.loads(decoded_chunk) chunk_json = json.loads(decoded_chunk)
# stream ended # stream ended
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
yield create_final_llm_result_chunk( yield create_final_llm_result_chunk(
index=chunk_index + 1, index=chunk_index + 1,
message=AssistantPromptMessage(content=""), message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered." finish_reason="Non-JSON encountered."
) )
if len(chunk_json['choices']) == 0: if not chunk_json or len(chunk_json['choices']) == 0:
continue continue
delta = chunk_json['choices'][0]['delta'] choice = chunk_json['choices'][0]
chunk_index = chunk_json['choices'][0]['index'] chunk_index = choice['index'] if 'index' in choice else chunk_index
if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''): if 'delta' in choice:
continue delta = choice['delta']
if delta.get('content') is None or delta.get('content') == '':
assistant_message_tool_calls = delta.get('tool_calls', None) continue
# assistant_message_function_call = delta.delta.function_call
assistant_message_tool_calls = delta.get('tool_calls', None)
# extract tool calls from response # assistant_message_function_call = delta.delta.function_call
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) # extract tool calls from response
# function_call = self._extract_response_function_call(assistant_message_function_call) if assistant_message_tool_calls:
# tool_calls = [function_call] if function_call else [] tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# transform assistant message to prompt message # tool_calls = [function_call] if function_call else []
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''), # transform assistant message to prompt message
tool_calls=tool_calls if assistant_message_tool_calls else [] assistant_prompt_message = AssistantPromptMessage(
) content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '') full_assistant_content += delta.get('content', '')
elif 'text' in choice:
if choice.get('text') is None or choice.get('text') == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=choice.get('text', '')
)
full_assistant_content += choice.get('text', '')
else:
continue
# check payload indicator for completion # check payload indicator for completion
if chunk_json['choices'][0].get('finish_reason') is not None: if chunk_json['choices'][0].get('finish_reason') is not None:
yield create_final_llm_result_chunk( yield create_final_llm_result_chunk(
index=chunk_index, index=chunk_index,
message=assistant_prompt_message, message=assistant_prompt_message,
finish_reason=chunk_json['choices'][0]['finish_reason'] finish_reason=chunk_json['choices'][0]['finish_reason']
) )
else: else:
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,
...@@ -373,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -373,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""), message=AssistantPromptMessage(content=""),
finish_reason="End of stream." finish_reason="End of stream."
) )
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, chunk_index += 1
prompt_messages: list[PromptMessage]) -> LLMResult:
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
response_json = response.json() response_json = response.json()
completion_type = LLMMode.value_of(credentials['mode']) completion_type = LLMMode.value_of(credentials['mode'])
...@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message) message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls: if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
in
message.tool_calls] message.tool_calls]
# function_call = message.tool_calls[0] # function_call = message.tool_calls[0]
# message_dict["function_call"] = { # message_dict["function_call"] = {
...@@ -484,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -484,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message_dict["name"] = message.name message_dict["name"] = message.name
return message_dict return message_dict
def _num_tokens_from_string(self, model: str, text: str, def _num_tokens_from_string(self, model: str, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
""" """
...@@ -507,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -507,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
""" """
Approximate num tokens with GPT2 tokenizer. Approximate num tokens with GPT2 tokenizer.
""" """
tokens_per_message = 3 tokens_per_message = 3
tokens_per_name = 1 tokens_per_name = 1
num_tokens = 0 num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict: for message in messages_dict:
...@@ -599,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): ...@@ -599,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
num_tokens += self._get_num_tokens_by_gpt2(required_field) num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens return num_tokens
def _extract_response_tool_calls(self, def _extract_response_tool_calls(self,
response_tool_calls: list[dict]) \ response_tool_calls: list[dict]) \
-> list[AssistantPromptMessage.ToolCall]: -> list[AssistantPromptMessage.ToolCall]:
......
...@@ -33,8 +33,8 @@ model_credential_schema: ...@@ -33,8 +33,8 @@ model_credential_schema:
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入您的 API endpoint URL zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Enter your API endpoint URL en_US: Base URL, eg. https://api.openai.com/v1
- variable: mode - variable: mode
show_on: show_on:
- variable: __model_type - variable: __model_type
......
import time import time
from decimal import Decimal from decimal import Decimal
from typing import Optional from typing import Optional
from urllib.parse import urljoin
import requests import requests
import json import json
...@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): ...@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key: if api_key:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url'] endpoint_url = urljoin(endpoint_url, 'embeddings')
extra_model_kwargs = {} extra_model_kwargs = {}
if user: if user:
...@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): ...@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key: if api_key:
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url'] endpoint_url = urljoin(endpoint_url, 'embeddings')
payload = { payload = {
'input': 'ping', 'input': 'ping',
...@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): ...@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
) )
if response.status_code != 200: if response.status_code != 200:
raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}") raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if 'model' not in json_result:
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response')
except CredentialsValidateFailedError:
raise
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))
...@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): ...@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'), ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1, ModelPropertyKey.MAX_CHUNKS: 1,
}, },
parameter_rules=[], parameter_rules=[],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment