Unverified Commit 77f9e8ce authored by Chenhe Gu, committed by GitHub

add example api url endpoint in placeholder (#1887)

Co-authored-by: takatost <takatost@gmail.com>
parent 5ca4c4a4
The LLM model implementation (OAIAPICompatLargeLanguageModel):

```diff
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin

 import requests
 import json
@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
     ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials['endpoint_url']
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

         # prepare the payload for a simple ping to the model
         data = {
@@ -105,8 +111,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     "content": "ping"
                 },
             ]
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
         elif completion_type is LLMMode.COMPLETION:
             data['prompt'] = 'ping'
+            endpoint_url = urljoin(endpoint_url, 'completions')
         else:
             raise ValueError("Unsupported completion type for model configuration.")
```
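The trailing-slash normalization added above matters because of how `urljoin` resolves relative references: without a trailing slash, the base URL's last path segment is replaced rather than extended. A minimal sketch of the standard-library behavior the new code relies on, using the base URL from the new placeholder text:

```python
from urllib.parse import urljoin

# Without a trailing slash, urljoin drops the base's last path segment:
assert urljoin('https://api.openai.com/v1', 'chat/completions') \
    == 'https://api.openai.com/chat/completions'   # '/v1' is lost

# With the slash the commit appends, the path is extended as intended:
assert urljoin('https://api.openai.com/v1/', 'chat/completions') \
    == 'https://api.openai.com/v1/chat/completions'
```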
```diff
@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )

             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                    and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'text_completion\'')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
```
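The new checks key off the `object` field that OpenAI-style servers include in their JSON responses. A sketch of response bodies that would pass each branch (only the `object` values come from the diff; the surrounding fields are illustrative, per the OpenAI API convention):

```python
# Chat mode: POST {base}/chat/completions must echo object == 'chat.completion'
chat_ok = {"object": "chat.completion",
           "choices": [{"message": {"role": "assistant", "content": "pong"}}]}

# Completion mode: POST {base}/completions must echo object == 'text_completion'
completion_ok = {"object": "text_completion",
                 "choices": [{"text": "pong"}]}

# Anything else (e.g. an HTML error page or a proxy response without an
# 'object' field) now fails fast with CredentialsValidateFailedError.
```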
```diff
@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
                 ParameterRule(
@@ -197,10 +221,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         return entity

     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, \
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
@@ -223,6 +247,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials["endpoint_url"]
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

         data = {
             "model": model,
@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         completion_type = LLMMode.value_of(credentials['mode'])

         if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
@@ -245,7 +273,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["tool_choice"] = "auto"

             for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

             data["tools"] = formatted_tools
```
Streaming response handling in the same file:

```diff
@@ -313,6 +341,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if chunk:
                     decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()

+                    chunk_json = None
                     try:
                         chunk_json = json.loads(decoded_chunk)
                     # stream ended
@@ -323,13 +352,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                             finish_reason="Non-JSON encountered."
                         )

-                    if len(chunk_json['choices']) == 0:
+                    if not chunk_json or len(chunk_json['choices']) == 0:
                         continue

-                    delta = chunk_json['choices'][0]['delta']
-                    chunk_index = chunk_json['choices'][0]['index']
+                    choice = chunk_json['choices'][0]
+                    chunk_index = choice['index'] if 'index' in choice else chunk_index

-                    if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
-                        continue
+                    if 'delta' in choice:
+                        delta = choice['delta']
+                        if delta.get('content') is None or delta.get('content') == '':
+                            continue

                         assistant_message_tool_calls = delta.get('tool_calls', None)
@@ -348,16 +379,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         )

                         full_assistant_content += delta.get('content', '')
+                    elif 'text' in choice:
+                        if choice.get('text') is None or choice.get('text') == '':
+                            continue
+
+                        # transform assistant message to prompt message
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=choice.get('text', '')
+                        )
+
+                        full_assistant_content += choice.get('text', '')
+                    else:
+                        continue

                     # check payload indicator for completion
                     if chunk_json['choices'][0].get('finish_reason') is not None:
                         yield create_final_llm_result_chunk(
                             index=chunk_index,
                             message=assistant_prompt_message,
                             finish_reason=chunk_json['choices'][0]['finish_reason']
                         )
                     else:
                         yield LLMResultChunk(
                             model=model,
@@ -374,6 +415,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     finish_reason="End of stream."
                 )

+            chunk_index += 1
+
     def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
```
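The reworked streaming loop now branches on the shape of each choice: chat endpoints stream incremental `delta` objects, while completion endpoints stream plain `text` fields. A sketch of the two server-sent-event payloads the branches correspond to (illustrative values, following the OpenAI streaming convention):

```python
# Shape handled by the 'delta' branch (chat/completions stream):
chat_chunk = {"choices": [{"index": 0,
                           "delta": {"content": "Hel"},
                           "finish_reason": None}]}

# Shape handled by the new 'text' branch (completions stream):
completion_chunk = {"choices": [{"index": 0,
                                 "text": "Hel",
                                 "finish_reason": None}]}

# Chunks matching neither shape are skipped, and a missing 'index' now
# falls back to the running chunk_index instead of raising KeyError.
```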
```diff
@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
-                                              message.tool_calls]
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in
+                                              message.tool_calls]
                 # function_call = message.tool_calls[0]
                 # message_dict["function_call"] = {
```
The model credential schema (YAML); the old zh_Hans placeholder "在此输入您的 API endpoint URL" reads "Enter your API endpoint URL here":

```diff
@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入您的 API endpoint URL
-        en_US: Enter your API endpoint URL
+        zh_Hans: Base URL, eg. https://api.openai.com/v1
+        en_US: Base URL, eg. https://api.openai.com/v1
     - variable: mode
       show_on:
         - variable: __model_type
```
The text embedding model implementation (OAICompatEmbeddingModel):

```diff
 import time
 from decimal import Decimal
 from typing import Optional
+from urllib.parse import urljoin

 import requests
 import json
@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'embeddings')

         extra_model_kwargs = {}
         if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"

-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
+        endpoint_url = urljoin(endpoint_url, 'embeddings')

         payload = {
             'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             )

             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if 'model' not in json_result:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response')
+        except CredentialsValidateFailedError:
+            raise
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
```
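For embeddings, the new validation is looser: any JSON body carrying a `model` key passes. A sketch of a minimal response that satisfies the check (only the `model` key requirement comes from the diff; the other fields are illustrative, OpenAI-style):

```python
# Minimal body accepted by the new embeddings check: only 'model' is required.
embeddings_ok = {
    "object": "list",
    "model": "text-embedding-ada-002",  # presence of this key is what is validated
    "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02]}],
}
```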
```diff
@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             model_type=ModelType.TEXT_EMBEDDING,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                 ModelPropertyKey.MAX_CHUNKS: 1,
             },
             parameter_rules=[],
```
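Both model entities now wrap `context_size` in `int(...)`. Presumably the credential form delivers values as strings, so the cast keeps downstream token-budget arithmetic from failing; a minimal sketch with a hypothetical credentials dict:

```python
credentials = {'context_size': '4096'}  # hypothetical form input; values arrive as strings

raw = credentials.get('context_size')
# raw * 0.8  # would raise TypeError: can't multiply sequence by non-int
usable = int(raw)
assert usable * 0.8 == 3276.8  # now safe to use as a number
```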