Unverified Commit cca9edc9 authored by takatost's avatar takatost Committed by GitHub

feat: ollama support (#2003)

parent 5e75f702
......@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
"files": files
})
else:
prompts.append({
"role": 'user',
"text": prompt_messages[0].content
prompt_message = prompt_messages[0]
text = ''
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
else:
text = prompt_message.content
params = {
"role": 'user',
"text": text,
}
if files:
params['files'] = files
prompts.append(params)
return prompts
......
......@@ -6,6 +6,7 @@
- huggingface_hub
- cohere
- togetherai
- ollama
- zhipuai
- baichuan
- spark
......
......@@ -54,5 +54,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your LocalAI, for example https://example.com/xxx
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
<svg width="82" height="24" viewBox="0 0 82 24" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<rect x="1" width="16.9688" height="24" fill="url(#pattern0)"/>
<path d="M71.4453 14.552C71.4453 13.6667 71.6266 12.8827 71.9893 12.2C72.3626 11.5174 72.864 10.9894 73.4933 10.616C74.1333 10.232 74.8373 10.04 75.6053 10.04C76.2986 10.04 76.9013 10.1787 77.4133 10.456C77.936 10.7227 78.352 11.0587 78.6613 11.464V10.184H80.5013V19H78.6613V17.688C78.352 18.104 77.9306 18.4507 77.3973 18.728C76.864 19.0054 76.256 19.144 75.5733 19.144C74.816 19.144 74.1226 18.952 73.4933 18.568C72.864 18.1734 72.3626 17.6294 71.9893 16.936C71.6266 16.232 71.4453 15.4374 71.4453 14.552ZM78.6613 14.584C78.6613 13.976 78.5333 13.448 78.2773 13C78.032 12.552 77.7066 12.2107 77.3013 11.976C76.896 11.7414 76.4586 11.624 75.9893 11.624C75.52 11.624 75.0826 11.7414 74.6773 11.976C74.272 12.2 73.9413 12.536 73.6853 12.984C73.44 13.4214 73.3173 13.944 73.3173 14.552C73.3173 15.16 73.44 15.6934 73.6853 16.152C73.9413 16.6107 74.272 16.9627 74.6773 17.208C75.0933 17.4427 75.5306 17.56 75.9893 17.56C76.4586 17.56 76.896 17.4427 77.3013 17.208C77.7066 16.9734 78.032 16.632 78.2773 16.184C78.5333 15.7254 78.6613 15.192 78.6613 14.584Z" fill="black"/>
<path d="M66.42 10.04C67.1134 10.04 67.732 10.184 68.276 10.472C68.8307 10.76 69.2627 11.1867 69.572 11.752C69.892 12.3174 70.052 13 70.052 13.8V19H68.244V14.072C68.244 13.2827 68.0467 12.68 67.652 12.264C67.2574 11.8374 66.7187 11.624 66.036 11.624C65.3534 11.624 64.8094 11.8374 64.404 12.264C64.0094 12.68 63.812 13.2827 63.812 14.072V19H62.004V14.072C62.004 13.2827 61.8067 12.68 61.412 12.264C61.0174 11.8374 60.4787 11.624 59.796 11.624C59.1134 11.624 58.5694 11.8374 58.164 12.264C57.7694 12.68 57.572 13.2827 57.572 14.072V19H55.748V10.184H57.572V11.192C57.8707 10.8294 58.2494 10.5467 58.708 10.344C59.1667 10.1414 59.6574 10.04 60.18 10.04C60.884 10.04 61.5134 10.1894 62.068 10.488C62.6227 10.7867 63.0494 11.2187 63.348 11.784C63.6147 11.2507 64.0307 10.8294 64.596 10.52C65.1614 10.2 65.7694 10.04 66.42 10.04Z" fill="black"/>
<path d="M44.6152 14.552C44.6152 13.6667 44.7966 12.8827 45.1592 12.2C45.5326 11.5174 46.0339 10.9894 46.6632 10.616C47.3032 10.232 48.0072 10.04 48.7752 10.04C49.4686 10.04 50.0712 10.1787 50.5832 10.456C51.1059 10.7227 51.5219 11.0587 51.8312 11.464V10.184H53.6712V19H51.8312V17.688C51.5219 18.104 51.1006 18.4507 50.5672 18.728C50.0339 19.0054 49.4259 19.144 48.7432 19.144C47.9859 19.144 47.2926 18.952 46.6632 18.568C46.0339 18.1734 45.5326 17.6294 45.1592 16.936C44.7966 16.232 44.6152 15.4374 44.6152 14.552ZM51.8312 14.584C51.8312 13.976 51.7032 13.448 51.4472 13C51.2019 12.552 50.8766 12.2107 50.4712 11.976C50.0659 11.7414 49.6286 11.624 49.1592 11.624C48.6899 11.624 48.2526 11.7414 47.8472 11.976C47.4419 12.2 47.1112 12.536 46.8552 12.984C46.6099 13.4214 46.4872 13.944 46.4872 14.552C46.4872 15.16 46.6099 15.6934 46.8552 16.152C47.1112 16.6107 47.4419 16.9627 47.8472 17.208C48.2632 17.4427 48.7006 17.56 49.1592 17.56C49.6286 17.56 50.0659 17.4427 50.4712 17.208C50.8766 16.9734 51.2019 16.632 51.4472 16.184C51.7032 15.7254 51.8312 15.192 51.8312 14.584Z" fill="black"/>
<path d="M43.1502 7.16016V19.0002H41.3262V7.16016H43.1502Z" fill="black"/>
<path d="M39.2498 7.16016V19.0002H37.4258V7.16016H39.2498Z" fill="black"/>
<path d="M30.2718 19.1123C29.2371 19.1123 28.2825 18.8723 27.4078 18.3923C26.5438 17.9017 25.8558 17.2243 25.3438 16.3603C24.8425 15.4857 24.5918 14.5043 24.5918 13.4163C24.5918 12.3283 24.8425 11.3523 25.3438 10.4883C25.8558 9.62433 26.5438 8.95233 27.4078 8.47233C28.2825 7.98166 29.2371 7.73633 30.2718 7.73633C31.3171 7.73633 32.2718 7.98166 33.1358 8.47233C34.0105 8.95233 34.6985 9.62433 35.1998 10.4883C35.7011 11.3523 35.9518 12.3283 35.9518 13.4163C35.9518 14.5043 35.7011 15.4857 35.1998 16.3603C34.6985 17.2243 34.0105 17.9017 33.1358 18.3923C32.2718 18.8723 31.3171 19.1123 30.2718 19.1123ZM30.2718 17.5283C31.0078 17.5283 31.6638 17.363 32.2398 17.0323C32.8158 16.691 33.2638 16.211 33.5838 15.5923C33.9145 14.963 34.0798 14.2377 34.0798 13.4163C34.0798 12.595 33.9145 11.875 33.5838 11.2563C33.2638 10.6377 32.8158 10.163 32.2398 9.83233C31.6638 9.50166 31.0078 9.33633 30.2718 9.33633C29.5358 9.33633 28.8798 9.50166 28.3038 9.83233C27.7278 10.163 27.2745 10.6377 26.9438 11.2563C26.6238 11.875 26.4638 12.595 26.4638 13.4163C26.4638 14.2377 26.6238 14.963 26.9438 15.5923C27.2745 16.211 27.7278 16.691 28.3038 17.0323C28.8798 17.363 29.5358 17.5283 30.2718 17.5283Z" fill="black"/>
<defs>
<pattern id="pattern0" patternContentUnits="objectBoundingBox" width="1" height="1">
<use xlink:href="#image0_16324_59298" transform="scale(0.00552486 0.00390625)"/>
</pattern>
<image id="image0_16324_59298" width="181" height="256" xlink:href=""/>
</defs>
</svg>
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g clip-path="url(#clip0_16325_59237)">
<rect width="24" height="24" rx="5" fill="white"/>
<rect x="3.5" width="17" height="24" fill="url(#pattern0)"/>
</g>
<defs>
<pattern id="pattern0" patternContentUnits="objectBoundingBox" width="1" height="1">
<use xlink:href="#image0_16325_59237" transform="matrix(0.00552486 0 0 0.00391344 0 -0.00092081)"/>
</pattern>
<clipPath id="clip0_16325_59237">
<rect width="24" height="24" fill="white"/>
</clipPath>
<image id="image0_16325_59237" width="181" height="256" xlink:href=""/>
</defs>
</svg>
import json
import logging
import re
from decimal import Decimal
from typing import Optional, Generator, Union, List, cast
from urllib.parse import urljoin
import requests
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
TextPromptMessageContent, SystemPromptMessage
from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class OllamaLargeLanguageModel(LargeLanguageModel):
"""
Model class for Ollama large language model.
"""
def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    Entry point for running the model; hands the request straight to ``_generate``.

    ``tools`` is part of the signature but is not passed on to ``_generate``.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling (not forwarded)
    :param stop: stop words
    :param stream: whether a streamed response is requested
    :param user: unique user id
    :return: full response or stream response chunk generator result
    """
    # Same arguments, in the positional order of _generate's signature.
    return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Estimate the number of prompt tokens.

    Chat models count every message via ``_num_tokens_from_messages``.
    Otherwise only the first message is counted: its string content, or the
    first text part of a multi-part content (empty string if there is none).

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool calling (not used here)
    :return: estimated token count
    """
    if self.get_model_mode(model, credentials) == LLMMode.CHAT:
        return self._num_tokens_from_messages(prompt_messages)

    content = prompt_messages[0].content
    if isinstance(content, str):
        return self._get_num_tokens_by_gpt2(content)

    # Multi-part content: take the first text part only.
    text_parts = (part.data for part in content if part.type == PromptMessageContentType.TEXT)
    return self._get_num_tokens_by_gpt2(next(text_parts, ''))
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Check the credentials by running a short, non-streaming "ping" generation.

    :param model: model name
    :param credentials: model credentials
    :raises CredentialsValidateFailedError: if the request fails for any reason
    """
    try:
        self._generate(
            model=model,
            credentials=credentials,
            prompt_messages=[UserPromptMessage(content="ping")],
            model_parameters={'num_predict': 5},
            stream=False,
        )
    except Exception as ex:
        # Invoke errors carry a dedicated description; anything else is stringified.
        reason = ex.description if isinstance(ex, InvokeError) else str(ex)
        raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {reason}')
def _generate(self, model: str, credentials: dict,
              prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
              stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
    """
    Send a chat or completion request to the Ollama server and wrap the reply.

    The endpoint (``api/chat`` or ``api/generate``) is chosen from
    ``credentials['mode']``. In completion mode only the first prompt message
    is used, and only when it is a ``UserPromptMessage``.

    :param model: model name
    :param credentials: credentials; ``base_url`` and ``mode`` are read
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters; ``format`` is sent as a top-level
        field, everything else goes into ``options``. The dict is not modified.
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id (not used by this method)
    :return: full response or stream response chunk generator result
    :raises InvokeError: if the server answers with a non-200 status code
    """
    headers = {
        'Content-Type': 'application/json'
    }

    endpoint_url = credentials['base_url']
    if not endpoint_url.endswith('/'):
        # urljoin would otherwise replace the last path segment of the base URL
        endpoint_url += '/'

    # Work on a copy so the caller's parameter dict is never mutated.
    options = dict(model_parameters or {})

    data = {
        'model': model,
        'stream': stream
    }

    # "format" is a top-level request field rather than a model option.
    if 'format' in options:
        data['format'] = options.pop('format')

    if stop:
        # Ollama takes stop sequences as a list inside "options".
        options['stop'] = list(stop)

    data['options'] = options

    completion_type = LLMMode.value_of(credentials['mode'])

    if completion_type is LLMMode.CHAT:
        endpoint_url = urljoin(endpoint_url, 'api/chat')
        data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
    else:
        endpoint_url = urljoin(endpoint_url, 'api/generate')
        first_prompt_message = prompt_messages[0]
        if isinstance(first_prompt_message, UserPromptMessage):
            first_prompt_message = cast(UserPromptMessage, first_prompt_message)
            if isinstance(first_prompt_message.content, str):
                data['prompt'] = first_prompt_message.content
            else:
                text = ''
                images = []
                for message_content in first_prompt_message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        text = message_content.data
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        # strip the data-URI prefix so only the base64 payload is sent
                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        images.append(image_data)

                data['prompt'] = text
                data['images'] = images

    # send the generation request
    response = requests.post(
        endpoint_url,
        headers=headers,
        json=data,
        timeout=(10, 60),
        stream=stream
    )

    response.encoding = "utf-8"
    if response.status_code != 200:
        raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

    if stream:
        return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)

    return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
                              response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
    """
    Turn a non-streaming Ollama response into an ``LLMResult``.

    Token usage is taken from ``prompt_eval_count`` / ``eval_count`` when the
    response carries both; otherwise it is estimated locally.

    :param model: model name
    :param credentials: model credentials
    :param completion_type: completion type
    :param response: response
    :param prompt_messages: prompt messages
    :return: llm result
    """
    response_json = response.json()

    if completion_type is LLMMode.CHAT:
        message = response_json.get('message', {})
        response_content = message.get('content', '')
    else:
        response_content = response_json['response']

    assistant_message = AssistantPromptMessage(content=response_content)

    if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
        # usage reported by the server
        prompt_tokens = response_json["prompt_eval_count"]
        completion_tokens = response_json["eval_count"]
    else:
        # Estimate locally. get_num_tokens copes with multi-part (list) content
        # and, in chat mode, counts every message rather than only the first.
        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
        completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)

    # transform usage
    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

    # transform response
    result = LLMResult(
        model=response_json["model"],
        prompt_messages=prompt_messages,
        message=assistant_message,
        usage=usage,
    )

    return result
def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                     response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
    """
    Turn a streaming Ollama response into a generator of ``LLMResultChunk``.

    Each non-empty line of the response is parsed as one JSON object. A line
    that is not valid JSON ends the stream with a final chunk whose finish
    reason is "Non-JSON encountered.". The chunk flagged ``done`` carries the
    usage, taken from the server's counters when present and estimated
    otherwise. The HTTP response is closed when the generator finishes or is
    closed by the consumer.

    :param model: model name
    :param credentials: model credentials
    :param completion_type: completion type
    :param response: response
    :param prompt_messages: prompt messages
    :return: llm response chunk generator result
    """
    full_text = ''
    chunk_index = 0

    def estimate_prompt_tokens() -> int:
        # get_num_tokens copes with multi-part (list) content and, in chat
        # mode, counts every message rather than only the first.
        return self.get_num_tokens(model, credentials, prompt_messages)

    def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
            -> LLMResultChunk:
        # estimate usage from the text accumulated so far
        prompt_tokens = estimate_prompt_tokens()
        completion_tokens = self._get_num_tokens_by_gpt2(full_text)

        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        return LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=index,
                message=message,
                finish_reason=finish_reason,
                usage=usage
            )
        )

    try:
        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
            if not chunk:
                continue

            try:
                chunk_json = json.loads(chunk)
            except json.JSONDecodeError:
                # unparsable line: close the stream with a final usage chunk
                yield create_final_llm_result_chunk(
                    index=chunk_index,
                    message=AssistantPromptMessage(content=""),
                    finish_reason="Non-JSON encountered."
                )
                chunk_index += 1
                break

            if not chunk_json:
                continue

            if completion_type is LLMMode.CHAT:
                # a missing or null "message" yields empty text
                text = (chunk_json.get('message') or {}).get('content', '')
            else:
                text = chunk_json.get('response', '')

            assistant_prompt_message = AssistantPromptMessage(
                content=text
            )

            full_text += text

            chunk_model = chunk_json.get('model', model)

            if chunk_json.get('done'):
                if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
                    # usage reported by the server
                    prompt_tokens = chunk_json["prompt_eval_count"]
                    completion_tokens = chunk_json["eval_count"]
                else:
                    # estimate locally
                    prompt_tokens = estimate_prompt_tokens()
                    completion_tokens = self._get_num_tokens_by_gpt2(full_text)

                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=chunk_model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                        finish_reason='stop',
                        usage=usage
                    )
                )
            else:
                yield LLMResultChunk(
                    model=chunk_model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                )

            chunk_index += 1
    finally:
        # release the connection even if the consumer stops iterating early
        response.close()
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
    """
    Translate a PromptMessage into the role/content dict sent to Ollama.

    For a user message with multi-part content, the last text part becomes
    "content" and every image part is added to "images" with its data-URI
    prefix removed.

    :param message: message to convert
    :return: dict with "role" and "content" (plus "images" for multi-part user content)
    :raises ValueError: if the message is not a user, assistant or system message
    """
    if isinstance(message, UserPromptMessage):
        if isinstance(message.content, str):
            return {"role": "user", "content": message.content}

        latest_text = ''
        encoded_images = []
        for part in message.content:
            if part.type == PromptMessageContentType.TEXT:
                latest_text = part.data
            elif part.type == PromptMessageContentType.IMAGE:
                encoded_images.append(re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', part.data))
        return {"role": "user", "content": latest_text, "images": encoded_images}

    if isinstance(message, AssistantPromptMessage):
        return {"role": "assistant", "content": message.content}

    if isinstance(message, SystemPromptMessage):
        return {"role": "system", "content": message.content}

    raise ValueError(f"Got unknown type {message}")
def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
    """
    Estimate the token count of a list of messages.

    Every message is converted to its Ollama dict form; the stringified keys
    and values of those dicts are counted and summed.

    :param messages: messages
    :return: estimated token count
    """
    # Convert everything first so a conversion error surfaces before any counting.
    converted = [self._convert_prompt_message_to_dict(m) for m in messages]
    return sum(
        self._get_num_tokens_by_gpt2(str(piece))
        for entry in converted
        for pair in entry.items()
        for piece in pair
    )
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
    """
    Build the schema for a user-configured Ollama LLM.

    Mode, context size, the upper bound of ``num_predict``, pricing and the
    optional vision feature are read from ``credentials``; the parameter
    rules themselves are fixed.

    :param model: model name
    :param credentials: credentials
    :return: model schema
    """
    extras = {}

    # the credential form stores the flag as the string 'true' / 'false'
    if credentials.get('vision_support') == 'true':
        extras['features'] = [ModelFeature.VISION]

    entity = AIModelEntity(
        model=model,
        label=I18nObject(
            zh_Hans=model,
            en_US=model
        ),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties={
            ModelPropertyKey.MODE: credentials.get('mode'),
            ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
        },
        parameter_rules=[
            ParameterRule(
                name=DefaultParameterName.TEMPERATURE.value,
                use_template=DefaultParameterName.TEMPERATURE.value,
                label=I18nObject(en_US="Temperature"),
                type=ParameterType.FLOAT,
                help=I18nObject(en_US="The temperature of the model. "
                                      "Increasing the temperature will make the model answer "
                                      "more creatively. (Default: 0.8)"),
                default=0.8,
                min=0,
                max=2
            ),
            ParameterRule(
                name=DefaultParameterName.TOP_P.value,
                use_template=DefaultParameterName.TOP_P.value,
                label=I18nObject(en_US="Top P"),
                type=ParameterType.FLOAT,
                help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                                      "more diverse text, while a lower value (e.g., 0.5) will generate more "
                                      "focused and conservative text. (Default: 0.9)"),
                default=0.9,
                min=0,
                max=1
            ),
            ParameterRule(
                name="top_k",
                label=I18nObject(en_US="Top K"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Reduces the probability of generating nonsense. "
                                      "A higher value (e.g. 100) will give more diverse answers, "
                                      "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
                default=40,
                min=1,
                max=100
            ),
            ParameterRule(
                name='repeat_penalty',
                label=I18nObject(en_US="Repeat Penalty"),
                type=ParameterType.FLOAT,
                help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
                                      "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                                      "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
                default=1.1,
                min=-2,
                max=2
            ),
            ParameterRule(
                name='num_predict',
                use_template='max_tokens',
                label=I18nObject(en_US="Num Predict"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
                                      "(Default: 128, -1 = infinite generation, -2 = fill context)"),
                default=128,
                min=-2,
                max=int(credentials.get('max_tokens', 4096)),
            ),
            ParameterRule(
                name='mirostat',
                label=I18nObject(en_US="Mirostat sampling"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
                                      "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
                default=0,
                min=0,
                max=2
            ),
            ParameterRule(
                name='mirostat_eta',
                label=I18nObject(en_US="Mirostat Eta"),
                type=ParameterType.FLOAT,
                help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
                                      "the generated text. A lower learning rate will result in slower adjustments, "
                                      "while a higher learning rate will make the algorithm more responsive. "
                                      "(Default: 0.1)"),
                default=0.1,
                precision=1
            ),
            ParameterRule(
                name='mirostat_tau',
                label=I18nObject(en_US="Mirostat Tau"),
                type=ParameterType.FLOAT,
                help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
                                      "A lower value will result in more focused and coherent text. (Default: 5.0)"),
                default=5.0,
                precision=1
            ),
            ParameterRule(
                name='num_ctx',
                label=I18nObject(en_US="Size of context window"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
                                      "(Default: 2048)"),
                default=2048,
                min=1
            ),
            ParameterRule(
                name='num_gpu',
                label=I18nObject(en_US="Num GPU"),
                type=ParameterType.INT,
                help=I18nObject(en_US="The number of layers to send to the GPU(s). "
                                      "On macOS it defaults to 1 to enable metal support, 0 to disable."),
                default=1,
                # No upper bound: the value is a layer count, so capping it at 1
                # would make it impossible to offload more than one layer.
                min=0
            ),
            ParameterRule(
                name='num_thread',
                label=I18nObject(en_US="Num Thread"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Sets the number of threads to use during computation. "
                                      "By default, Ollama will detect this for optimal performance. "
                                      "It is recommended to set this value to the number of physical CPU cores "
                                      "your system has (as opposed to the logical number of cores)."),
                min=1,
            ),
            ParameterRule(
                name='repeat_last_n',
                label=I18nObject(en_US="Repeat last N"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
                                      "(Default: 64, 0 = disabled, -1 = num_ctx)"),
                default=64,
                min=-1
            ),
            ParameterRule(
                name='tfs_z',
                label=I18nObject(en_US="TFS Z"),
                type=ParameterType.FLOAT,
                help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                                      "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                                      "while a value of 1.0 disables this setting. (default: 1)"),
                default=1,
                precision=1
            ),
            ParameterRule(
                name='seed',
                label=I18nObject(en_US="Seed"),
                type=ParameterType.INT,
                help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
                                      "a specific number will make the model generate the same text for "
                                      "the same prompt. (Default: 0)"),
                default=0
            ),
            ParameterRule(
                name='format',
                label=I18nObject(en_US="Format"),
                type=ParameterType.STRING,
                help=I18nObject(en_US="the format to return a response in."
                                      " Currently the only accepted value is json."),
                options=['json'],
            )
        ],
        pricing=PriceConfig(
            input=Decimal(credentials.get('input_price', 0)),
            output=Decimal(credentials.get('output_price', 0)),
            unit=Decimal(credentials.get('unit', 0)),
            currency=credentials.get('currency', "USD")
        ),
        **extras
    )

    return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map exceptions raised by ``requests`` to the unified invoke error types.

    Keys are the error types surfaced to the caller; values are the
    ``requests`` exception types converted into them. ``HTTPError`` is listed
    under both the bad-request and the server-unavailable entries.

    :return: Invoke error mapping
    """
    errors = requests.exceptions

    mapping = {}
    mapping[InvokeAuthorizationError] = [
        errors.InvalidHeader,  # Missing or Invalid API Key
    ]
    mapping[InvokeBadRequestError] = [
        errors.HTTPError,  # Invalid Endpoint URL or model name
        errors.InvalidURL,  # Misconfigured request or other API error
    ]
    mapping[InvokeRateLimitError] = [
        errors.RetryError,  # Too many requests sent in a short period of time
    ]
    mapping[InvokeServerUnavailableError] = [
        errors.ConnectionError,  # Engine Overloaded
        errors.HTTPError,  # Server Error
    ]
    mapping[InvokeConnectionError] = [
        errors.ConnectTimeout,  # Timeout
        errors.ReadTimeout,  # Timeout
    ]
    return mapping
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
    # NOTE(review): the class name looks like a copy-paste leftover; presumably
    # this is the provider class for Ollama. Confirm how the class is looked up
    # before renaming it.

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.

        Intentionally a no-op: nothing is checked and nothing is raised here.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        return None
provider: ollama
label:
en_US: Ollama
icon_large:
en_US: icon_l_en.svg
icon_small:
en_US: icon_s_en.svg
background: "#F9FAFB"
help:
title:
en_US: How to integrate with Ollama
zh_Hans: 如何集成 Ollama
url:
en_US: https://docs.dify.ai/advanced/model-configuration/ollama
supported_model_types:
- llm
- text-embedding
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: base_url
label:
zh_Hans: 基础 URL
en_US: Base URL
type: text-input
required: true
placeholder:
zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: 模型类型
en_US: Completion mode
type: select
required: true
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: '4096'
type: text-input
required: true
- variable: vision_support
label:
zh_Hans: 是否支持 Vision
en_US: Vision support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: "Yes"
zh_Hans: 是
- value: 'false'
label:
en_US: "No"
zh_Hans: 否
import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
import numpy as np
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
logger = logging.getLogger(__name__)
class OllamaEmbeddingModel(TextEmbeddingModel):
"""
Model class for an Ollama text embedding model.
"""
def _invoke(self, model: str, credentials: dict,
            texts: list[str], user: Optional[str] = None) \
        -> TextEmbeddingResult:
    """
    Embed each text with one request to the Ollama ``api/embeddings`` endpoint.

    Texts whose estimated token count exceeds the model's context size are
    truncated to a proportional prefix before being sent.

    :param model: model name
    :param credentials: model credentials; ``base_url`` is read
    :param texts: texts to embed
    :param user: unique user id (not used by this method)
    :return: embeddings result
    """
    headers = {
        'Content-Type': 'application/json'
    }

    endpoint_url = credentials.get('base_url')
    if not endpoint_url.endswith('/'):
        # urljoin would otherwise replace the last path segment of the base URL
        endpoint_url += '/'

    endpoint_url = urljoin(endpoint_url, 'api/embeddings')

    # get model properties
    context_size = self._get_context_size(model, credentials)

    inputs = []
    for text in texts:
        # Token count is only an approximation based on the GPT2 tokenizer
        num_tokens = self._get_num_tokens_by_gpt2(text)

        if num_tokens > context_size:
            # Keep a prefix whose length is scaled by context_size / num_tokens.
            # The ratio must be applied before rounding down; flooring the
            # ratio itself gives 0 and would send an empty prompt.
            cutoff = len(text) * context_size // num_tokens
            inputs.append(text[:cutoff])
        else:
            inputs.append(text)

    used_tokens = 0
    batched_embeddings = []

    for text in inputs:
        payload = {
            'prompt': text,
            'model': model,
        }

        # one request per text against the Ollama embeddings endpoint
        response = requests.post(
            endpoint_url,
            headers=headers,
            data=json.dumps(payload),
            timeout=(10, 300)
        )

        response.raise_for_status()  # Raise an exception for HTTP errors
        response_data = response.json()

        # the response carries the vector only; tokens are counted locally
        embeddings = response_data['embedding']
        embedding_used_tokens = self.get_num_tokens(model, credentials, [text])

        used_tokens += embedding_used_tokens
        batched_embeddings.append(embeddings)

    # calc usage
    usage = self._calc_response_usage(
        model=model,
        credentials=credentials,
        tokens=used_tokens
    )

    return TextEmbeddingResult(
        embeddings=batched_embeddings,
        usage=usage,
        model=model
    )
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
    """
    Estimate the total token count of the given texts with the GPT2 tokenizer.

    :param model: model name
    :param credentials: model credentials
    :param texts: texts to embed
    :return: sum of the per-text token estimates
    """
    total = 0
    for item in texts:
        total += self._get_num_tokens_by_gpt2(item)
    return total
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Check the credentials by embedding the single text "ping".

    :param model: model name
    :param credentials: model credentials
    :raises CredentialsValidateFailedError: if the request fails for any reason
    """
    try:
        self._invoke(model=model, credentials=credentials, texts=['ping'])
    except Exception as ex:
        # Invoke errors carry a dedicated description; anything else is stringified.
        reason = ex.description if isinstance(ex, InvokeError) else str(ex)
        raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {reason}')
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
    """
    Build the schema for a user-configured Ollama embedding model.

    Context size and pricing are read from ``credentials``. A missing or
    empty ``context_size`` falls back to 4096, the same default the LLM
    schema and the credential form use.

    :param model: model name
    :param credentials: model credentials
    :return: model schema
    """
    entity = AIModelEntity(
        model=model,
        label=I18nObject(en_US=model),
        model_type=ModelType.TEXT_EMBEDDING,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties={
            # int(None) would raise a TypeError, so default before converting
            ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size') or 4096),
            ModelPropertyKey.MAX_CHUNKS: 1,
        },
        parameter_rules=[],
        pricing=PriceConfig(
            input=Decimal(credentials.get('input_price', 0)),
            unit=Decimal(credentials.get('unit', 0)),
            currency=credentials.get('currency', "USD")
        )
    )

    return entity
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
    """
    Build the usage record for an embedding call.

    Pricing comes from ``get_price`` for input tokens; latency is the time
    elapsed since ``self.started_at``.

    :param model: model name
    :param credentials: model credentials
    :param tokens: input tokens
    :return: usage
    """
    price = self.get_price(
        model=model,
        credentials=credentials,
        price_type=PriceType.INPUT,
        tokens=tokens,
    )

    # Embeddings have no output tokens, so the total equals the input count.
    usage_fields = {
        'tokens': tokens,
        'total_tokens': tokens,
        'unit_price': price.unit_price,
        'price_unit': price.unit,
        'total_price': price.total_amount,
        'currency': price.currency,
        'latency': time.perf_counter() - self.started_at,
    }
    return EmbeddingUsage(**usage_fields)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map `requests` exceptions to the unified invoke errors

    Each key is the unified error type raised to the caller; each value lists
    the exception types thrown by the model call that should be converted
    into that unified error.

    :return: Invoke error mapping
    """
    request_errors = requests.exceptions

    # NOTE(review): HTTPError appears under both InvokeBadRequestError and
    # InvokeServerUnavailableError; which entry applies depends on how the
    # caller walks this mapping - confirm against the resolver.
    mapping = {}
    mapping[InvokeAuthorizationError] = [
        request_errors.InvalidHeader,  # Missing or Invalid API Key
    ]
    mapping[InvokeBadRequestError] = [
        request_errors.HTTPError,  # Invalid Endpoint URL or model name
        request_errors.InvalidURL,  # Misconfigured request or other API error
    ]
    mapping[InvokeRateLimitError] = [
        request_errors.RetryError,  # Too many requests sent in a short period of time
    ]
    mapping[InvokeServerUnavailableError] = [
        request_errors.ConnectionError,  # Engine Overloaded
        request_errors.HTTPError,  # Server Error
    ]
    mapping[InvokeConnectionError] = [
        request_errors.ConnectTimeout,  # Timeout
        request_errors.ReadTimeout,  # Timeout
    ]
    return mapping
......@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
......
......@@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Base URL, eg. https://api.openai.com/v1
zh_Hans: Base URL, e.g. https://api.openai.com/v1
en_US: Base URL, e.g. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type
......
......@@ -33,5 +33,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000
......@@ -34,8 +34,8 @@ model_credential_schema:
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
- variable: model_uid
label:
zh_Hans: 模型UID
......
......@@ -121,6 +121,7 @@ class PromptTransform:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config
......@@ -341,6 +342,13 @@ class PromptTransform:
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)
......@@ -434,6 +442,7 @@ class PromptTransform:
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: List[FileObj],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> List[PromptMessage]:
......@@ -461,6 +470,13 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=prompt))
return prompt_messages
......
......@@ -62,5 +62,8 @@ COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
# Ollama Credentials
OLLAMA_BASE_URL=
# Mock Switch
MOCK_SWITCH=false
\ No newline at end of file
import os
from typing import Generator
import pytest
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
SystemPromptMessage, TextPromptMessageContent, ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel
def test_validate_credentials():
    """Validation fails for the localhost:21434 endpoint and passes for OLLAMA_BASE_URL."""
    llm = OllamaLargeLanguageModel()
    shared_settings = {'mode': 'chat', 'context_size': 2048, 'max_tokens': 2048}

    # Presumably nothing is listening on this port, so validation must fail.
    bad_credentials = {'base_url': 'http://localhost:21434', **shared_settings}
    with pytest.raises(CredentialsValidateFailedError):
        llm.validate_credentials(model='mistral:text', credentials=bad_credentials)

    good_credentials = {'base_url': os.environ.get('OLLAMA_BASE_URL'), **shared_settings}
    llm.validate_credentials(model='mistral:text', credentials=good_credentials)
def test_invoke_model():
    """A blocking chat-mode invocation returns an LLMResult with non-empty content."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 2048,
        'max_tokens': 2048,
    }
    sampling = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5, 'num_predict': 10}
    question = UserPromptMessage(content='Who are you?')

    result = llm.invoke(
        model='mistral:text',
        credentials=credentials,
        prompt_messages=[question],
        model_parameters=sampling,
        stop=['How'],
        stream=False
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """A streaming chat-mode invocation yields well-typed LLMResultChunk objects."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 2048,
        'max_tokens': 2048,
    }
    sampling = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5, 'num_predict': 10}
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Who are you?'),
    ]

    stream = llm.invoke(
        model='mistral:text',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=True
    )

    assert isinstance(stream, Generator)

    # Every streamed piece must carry an assistant message delta.
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
def test_invoke_completion_model():
    """A blocking completion-mode invocation returns an LLMResult with non-empty content."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'completion',
        'context_size': 2048,
        'max_tokens': 2048,
    }
    sampling = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5, 'num_predict': 10}
    question = UserPromptMessage(content='Who are you?')

    result = llm.invoke(
        model='mistral:text',
        credentials=credentials,
        prompt_messages=[question],
        model_parameters=sampling,
        stop=['How'],
        stream=False
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_completion_model():
    """A streaming completion-mode invocation yields well-typed LLMResultChunk objects."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'completion',
        'context_size': 2048,
        'max_tokens': 2048,
    }
    sampling = {'temperature': 1.0, 'top_k': 2, 'top_p': 0.5, 'num_predict': 10}
    conversation = [
        SystemPromptMessage(content='You are a helpful AI assistant.'),
        UserPromptMessage(content='Who are you?'),
    ]

    stream = llm.invoke(
        model='mistral:text',
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=['How'],
        stream=True
    )

    assert isinstance(stream, Generator)

    # Every streamed piece must carry an assistant message delta.
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
def test_invoke_completion_model_with_vision():
    """A completion-mode call to `llava` with text plus image content returns non-empty output."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'completion',
        'context_size': 2048,
        'max_tokens': 2048,
    }

    # One user message mixing a text part with an image part.
    multimodal_parts = [
        TextPromptMessageContent(data='What is this in this picture?'),
        ImagePromptMessageContent(data=''),
    ]
    user_message = UserPromptMessage(content=multimodal_parts)

    answer = llm.invoke(
        model='llava',
        credentials=credentials,
        prompt_messages=[user_message],
        model_parameters={'temperature': 0.1, 'num_predict': 100},
        stream=False,
    )

    assert isinstance(answer, LLMResult)
    assert len(answer.message.content) > 0
def test_invoke_chat_model_with_vision():
    """A chat-mode call to `llava` with text plus image content returns non-empty output."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 2048,
        'max_tokens': 2048,
    }

    # One user message mixing a text part with an image part.
    multimodal_parts = [
        TextPromptMessageContent(data='What is this in this picture?'),
        ImagePromptMessageContent(data=''),
    ]
    user_message = UserPromptMessage(content=multimodal_parts)

    answer = llm.invoke(
        model='llava',
        credentials=credentials,
        prompt_messages=[user_message],
        model_parameters={'temperature': 0.1, 'num_predict': 100},
        stream=False,
    )

    assert isinstance(answer, LLMResult)
    assert len(answer.message.content) > 0
def test_get_num_tokens():
    """Token counting for a single 'Hello World!' user message returns the integer 6."""
    llm = OllamaLargeLanguageModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 2048,
        'max_tokens': 2048,
    }
    greeting = UserPromptMessage(content='Hello World!')

    token_count = llm.get_num_tokens(
        model='mistral:text',
        credentials=credentials,
        prompt_messages=[greeting]
    )

    assert isinstance(token_count, int)
    assert token_count == 6
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
def test_validate_credentials():
    """Validation fails for the localhost:21434 endpoint and passes for OLLAMA_BASE_URL."""
    embedder = OllamaEmbeddingModel()
    shared_settings = {'mode': 'chat', 'context_size': 4096}

    # Presumably nothing is listening on this port, so validation must fail.
    bad_credentials = {'base_url': 'http://localhost:21434', **shared_settings}
    with pytest.raises(CredentialsValidateFailedError):
        embedder.validate_credentials(model='mistral:text', credentials=bad_credentials)

    good_credentials = {'base_url': os.environ.get('OLLAMA_BASE_URL'), **shared_settings}
    embedder.validate_credentials(model='mistral:text', credentials=good_credentials)
def test_invoke_model():
    """Embedding two texts returns two embeddings and a total token usage of 2."""
    embedder = OllamaEmbeddingModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 4096,
    }
    inputs = ["hello", "world"]

    outcome = embedder.invoke(
        model='mistral:text',
        credentials=credentials,
        texts=inputs,
        user="abc-123"
    )

    assert isinstance(outcome, TextEmbeddingResult)
    assert len(outcome.embeddings) == 2
    assert outcome.usage.total_tokens == 2
def test_get_num_tokens():
    """Token counting for the texts "hello" and "world" returns 2."""
    embedder = OllamaEmbeddingModel()

    credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 4096,
    }
    inputs = ["hello", "world"]

    token_count = embedder.get_num_tokens(
        model='mistral:text',
        credentials=credentials,
        texts=inputs
    )

    assert token_count == 2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment