Unverified Commit cca9edc9 authored by takatost's avatar takatost Committed by GitHub

feat: ollama support (#2003)

parent 5e75f702
......@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
"files": files
})
else:
prompts.append({
prompt_message = prompt_messages[0]
text = ''
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
files.append({
"type": 'image',
"data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
"detail": content.detail.value
})
else:
text = prompt_message.content
params = {
"role": 'user',
"text": prompt_messages[0].content
})
"text": text,
}
if files:
params['files'] = files
prompts.append(params)
return prompts
......
......@@ -6,6 +6,7 @@
- huggingface_hub
- cohere
- togetherai
- ollama
- zhipuai
- baichuan
- spark
......
......@@ -54,5 +54,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your LocalAI, for example https://example.com/xxx
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g clip-path="url(#clip0_16325_59237)">
<rect width="24" height="24" rx="5" fill="white"/>
<rect x="3.5" width="17" height="24" fill="url(#pattern0)"/>
</g>
<defs>
<pattern id="pattern0" patternContentUnits="objectBoundingBox" width="1" height="1">
<use xlink:href="#image0_16325_59237" transform="matrix(0.00552486 0 0 0.00391344 0 -0.00092081)"/>
</pattern>
<clipPath id="clip0_16325_59237">
<rect width="24" height="24" fill="white"/>
</clipPath>
<image id="image0_16325_59237" width="181" height="256" xlink:href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAALUAAAEACAMAAADC/cfpAAAC8VBMVEUAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADF5N8AAAA+nRSTlMAQP4BYIAC/AP4/foEJCYH+59j9ggKCUSkE/Dy2g8G+W71Gn+nCyGX8SI4IEHTV1YNEO/R0u0V5XYvGLbWsw7BrEt+UgUZPvdw1Tl3xoSP8+dF3BbdH+bQib/fJ2U61ze+cgwbdeP0YYF6TFqyR2Lkq7DhxbjYnVSpjdRp6RFZ6DIqPcoSraBvPy3HSOvASc7LbCx7WJ4zFx0wfe5t4KJGFEJNT19KpY5Vaxy9aimMi5qRpl2bNZajUbnqu9vDxHk8hXzM7HivnFxTMTTik04eySsuiiXeqtlnkrzIrjuht2SYXpRQtIi1mSious+HkHSGg5VxNiOxzVtoemhgZAAAEO5JREFUeF7t3WVwHNeWB/C/uqWZEcsRWbLAki2ZQWaGmGNmZmbH9BJzzOyAHVOcOGjHkDhML8zwwkybDT7mpf+nLclS1bk93T13ZvSs3i39PqpKM0ddd85cOkf4f6tWrVq1atWqVatWrVq1zOTFJ+LiFvbph3+Vfn0WxsWdWJxsorqUzRnEy4y932Wg+mV8t9fgZYPmlKE6BOq3p3CqpYnqZbY8RaF9/QCidq4nLVaUoTqVraBFz3OI0vGODDKtFNWndBqDdDyOqNx6kjaWp6C6pCynjZO3IgrJG2jrk2JUj+JPaGtDMiJW0o4OvkpFdUj9ig7alSBSl+joIVSHh+joEiK01KCjhGxELzuBjoyliEjGKLoYlYVoZbm/QQYi0Ymurka0rqarTohAWXu6qncU0Tlaj67alyF8LzGE6xGd6xnCSwhbxmYqipYNTqNiXgdEo8M8KtIGLyuiYnMGwnUPFf7ewCMGFScQjRNUGI8Avf
1U3INwraJiKAA8QEXHvohc345UPAAAQ6lYhTClGJTWpAJAajcqnkTknqSi2+U3WEPJSEF44qjogwpLqbgWkbuWiqWo0IeKOISnBaV2Plw2g1JhDiKVU0hpBi7ztaPUAmFJpmIZKmVbfx6pZVRkO/w8GeGoTym3BJXMFpS6I1LdKbUwUakkl1L9KF70dVTBbZTSMxGZzHRKt6EKXo/8sSQ1pjRSfDcUUtqHyOyjVCi+r0ZSapwEfbdSOinzcldKDyIyD1LqKvP4SUq3Qp86Dq6DMIXSjYjMjZSmQLjOMnb03e78osUGhbQeiESPNApGsfNjuR36BlFKgdSOUnNEojmldpBSKA2Ctow0CvtNSMModaqGBccwSOZ+CmkZ0HWcUlco3qT0KSLxKaU3oehK6Th0taI0FIp+lLrBwnzvyVu+nTHtVDppbJi0rftD66fOglU3Sv2gGEqpFXTdQuksVCspJKZCaLJ+xQgGSTt05t5MCKmJFFZCdZbSLdC1ltKHrt+bN6DKxQUt6Kjo68WzUOUG12nBh5TWQtczFIwsqF6kdDcqtP58G0MYfuQaExXupvQiVFkGhWciS3wjYLGY0j8AYOyZXOq4cXFfAPgHpcWwGBFZ6mtDYS8s3qZ0BmiytpC6xlzVFzhD6W1Y7KXQBpoClH4PizJKvTp8UchwfPam2YvSaVj8nlIAes5R+g2scilsOMlwTVxDIRdWv6F0DnqahFi7dWN16hZizdoEeqZSugCr86xO52F1gdJU6ImhFOs6I4ze7bCKpRQTUdR9YPUFq9MXsOpTDVHHuI676MVpvX/0UXdiKMby5566feeKGW3nMaROVyjq+XTT+NqZ37RGJXPsr64/ZNDN/CsUdSwdtf9tdgGsZl/omUBHsTUd9XObMmDvPya8cmWjvgtWLWnH6H4NXJTsWk5bLWF1VzVEHeu8VpWebYoQsjq1oY3m1ZSvt1Ka4jSCpHGLTYQ2uo5BwSmqKZS2Qk9yiBO6mbR6Yjv0jHyHVjNDnOglQ88xSsNg0aMNVfUWQ9vG12jRpgcshlE6Bj1JlOrC4nmqfn0DwmAuTKDqeVjUpZQETfUo9ESVnJhNc+IWNNhAxd4OCE/LeCo2NFgQN2dTTA6q9KRQD7ryKEwCZr0QG3dfi+G0c6A1wrU0nXaGt7gvLvaFWcAkCnnQNZGCf+D5Ijp6vzXCtyiNjuo9ushPYSJ01aWuU0cRifuprS50/Y2aEqcjIuYT1PU36IqlphOIUNlKaoqFrm+o50ASIrU1kXq+ga6SAdQxoBki9wC1DCiBtobUcT2iUNaGOhpCW/4Batifg2jMpI6us6Ap52XqmIOopC6njvdzoGV0W+oY0AHRWUAtbUdDw/YW1HIfonQukVrWbdd40jdqvtZYROumQmq5MeTTHtuN9ozPGt6yKeZis9/FtFz4ep1Lt7VG9N7+uU6dYRcONi0tfWzf52e+zKWDbmPhKmUu7RR9+kgX/MsFHmvQlrbmprgG3ZE22n2ejyvEPN5oHm10dAm73yQGe3mRD1fS6ReHM9ikfnCQP55BOr7hw5XWb6fBIOPzYcv3E62MR8tQE5qPYZCffLDzJ1qd/BE1pEddBvkTbIw0aNGiCWqMeVsaLYyRCNJhOS2eyUdN2jecFss7wGoHLa5NRc2a2oYWO0Lu3S0pQE1rWi/EnmDBNKqezULNG5lI1bQCSBeoGnMaXnAVLS5AKNlARVFneEMjqjbIZeSrVE2BR2SMp+pVsYu/n4olJrxiTzwV4xyHT+5seEcDKgakopL5X1QMhIdk/JqK/7a/S8vPCuAlfaiY6FD51hKeYn5FKaEUFZJGUBrkg7f0ts0iMVQsg7dYqwJutKvH2p8Fr/kfSsY5m2vLw+A5PYZTugoAZhuU3oL31KX0MAAcppRnwnvepTTXDMp7D8CDMosovQfgOUrZ8KJ/sxbNmfUoxGfCiwZbK4aLLcnQk16wVsLspvTv8KTWCRQOAd9TugneNIbCK8DCoI
ov738cA5hAKRnRMpNbLni456iV8QlMiF85qufDC1omm9W8fjyGOpSy4Cz/jth7svPh5tj6HfsZbP+O9cfgJj/7ntg78rVXNE3UqNPg6Ph18SQZf93NcLD9piEGnRhDbtoOBzdXvfRxOJlCaaoa9Ttw4BvqZ6X4gSZsfPB6Ot2lv/4BbJgD41nJP9SndR8hRo16Ehw8HqKTwA29EhhaQq8bQnRGeDySqFfC3mH31h45jfzU42+U496k5HAEUcfDVmAUFd18kDY9TX1Pb4Lk60bFqIDGuN6qlUNi3W5p5jdkeBrmu90ejdXIIcm4PnS+9uXR4kWRAPIYrjyRhl6kRZ4vdL4eq/PdOJ1WdVBlaWNa+cfv7HR2a5NmY5s12Xq2087xflo1XooqdWg1PfRUtcAyzu+HjZnOUbcspGpEo7t6QNXjrkYjqCps6Rz1TNgYZZmH3Bl6zhdHSa54WqVR8r92MAA7gYOv+SmltXK8SBSnM+ebTellvah/hQrZhUokR5LhLPlIGoXCbFT4lVbUu61Vt2YuhcTWOs2Icmeh3OQ2FL78AO4++JJCm8koNytXp/XQS5QaAHiW0l1OKwmpEcpltpAfsGUmQjGXyY9ui0yRHYQXEOyAdd2IMyEHtm8NFemlKPeDLCn/GDo+lsXsP6BcaToVa3wIsj2R0uygkbU5CcHupmTMR7n+hij+bg09rbuK1+mPcvMNSncj2HwGTZY6JFDqDRt/lkHvQrkk8dh6FUBXgajIbJeEcrtk2H+GjfOUdqLcy5T+EzZ8f01gJf/AoFOR1QXQV7DasmMHDPSzUsJffQiWkkBpsU1i8yfDztTVfpZbF4MKSWNYZVo+wpE/jVXGJKFCzDqW86+eCjs/UErMQbnOVDwKezmLFsa9utUHa56N74zwdI635n34tr4at3BRDmx1KaL0B1QwB1FKTIGGp1wqU3rMrHNpUwGAZnF1Hm/u1ivuKWgQZbxiWGFw+F2cuvhZaWUGLPrlkeRn9/YYVkjSWACrjJWs5O+CkC4mUpqXXxVCIRXvhnPAvd5pt9l4mhWMt2G1Xjy4UHwTqTiCKvdR0TEHruRv5GbCqk2oehhk5upfG72JdGiUc6tBxbU+7R2sIwjSmKqFCHKElcYghM7pVMxQGo2ounaAqxJDZAGrh6kwHoOkZiCjBK5GryEd99f/aVC1biOcyVYaXRDkvc22vQOkLqzinjZPd6PqgAnhOlrcrlcB2R42Loqjb2NYADbaa1Uvmr9QZTSF1KyIKn+yVtRtYafHQ2m8bFRv2GqrFfXNtHgCqltosUsr6gMQgotkO2bB3gGtqHdRdbIfVAXjqXpcK+o82LtTLOjs5GlF/TgF2+zeuZCKv8NFf1YaEHDtlJcegK3AAFbqDxd/p+KQCavJiVTM1OvVsQe2drqniD16fTZmUpE4OWSWvQgXgXiXezuybdgm2BrISvEBuLhoUPEwLH7nD2sKtY2VhsDWJPfasyGstC2slpr+37m2eGLjZrrdcrfAzhLXFLFFtyirWWO6NX/qEs+w+j1ms8oS2Ol8+e16mbCzRPsAvD4V8V3cmjq3TYK7pM3uy2NMbvSHn7rf1jfElaDNId+nrUv76sA7VNyBUD5ilXFdEJ4u41jlI4RyBxXvyFR6LxXnTYQyNpFVVmUiHJmrWCVxLEIxz1Nxr3M9Y3+E9qjYWQgn7MzVcmEdWn/HesdZ8ygNMRFacT3xC0eh6+gQUcRdjNDMIZTmzRJdBCNo5DdQrtmyoSe7o+7N0tDB1aU0rgCC3jLUWKvzuI8eMWR9vA86CsbZ16cntac0AXqKlZcb/nMx3BX/PJzCuGLomUCpfVWybErFzdA0vYhSYq8+mXCS2adXIqWi6ZD01wZNbVcE60zoOhhPVfozcYtKTajM0kVxz6RTFX8Qusx1to0on6I0FPr2FTFYwtw1E5/9ZNL4ieMnffLsxDVzExisaB/0DaX0FCr46lG6BmGYPo6RGDcdYbiGUj2fTcfNNg
GEY3ZPhq/nbIQj0Mamq+gblJYgPIHnBzA8A54PIDxLKL1hk1kGI1wpvzAcv6QgXINtMnP3qJss36k/THreGXUL5+7BXXCMjYhE07rpDC29blNEYqMR1CnHTKQwFxEqm79kHt3MWzK/DBGaSyHRDOqxPAORS909eMe0RAZLnLZj8O5URG5GUA/ntyhdQpTM4j0xqj3FJqJ0Keg6/khKV8OLrg66evO9+1aaB+scvw+6inU3vOhuSrFBf0cMvCgmaDycoHQNvOgaSieCjtFL4UWllOKCok6BF6UERT2H0m540W5Kc4IuujSHFzWnND8o8x2GFx0Oynx9vF/EgZso9QnaSJsAL5oQ9NlLprQCXrSC0mQAJZTawovaUioBgM0U4vvCe/rGU9iMcs95vpb0LUrPodxvKa2H96yn9Fubc6aG8J6GlOrbTKie9sFrfE/bTEuz4tUfenyeGp+FCkMoPQiveZDSENud1rkBeEtgLqWh9odjvb3dHaK//blMV3hLV/tzGeyklLAHXrIngdJOVDlIxQoPz5x4EFX6jqCUMB3eMT2B0oi+TvNXbgvAKwLb6Dj/b+Kn4i/wir9Q4W8C4TUq0rfAG7akU/EapKZUjdkOL9g+hqqmUKymqqEHJ3vkaqgeS6DC+BE170eDioTHQnX7z0PNywv5XwBmN6ZqMmraZKoazw79nw/eRE17kwxZnuubQUUsalosFTN8sLGWilaoaa2oWKsxjeVY1LRjVPUO3b3vj6h5f6Ti1xmh7rHyXc/11CLrh/rvAavgBauoWJ4K1SNUTYUXTKXqEfe/age8YYfrCJhsWCszvSElkZIxGVIDerS14DAqGkAaT2lEa3hF6xGUxkM4ZlCaA++YQ8l4z/Emw4AceEfOAMeiqico9YKX9HKsujtF6bCHT0hPiTpTSmkb4SUb0yiNdijheB/e8r5DEcdgT2+7P+hwI/VbSmfhLWcpfSsq6KRSD1/EkVV6r1huZXuLeqv9FVTKonQIXnPIttXax5QaevyQ9GPb06SP4DUf2Z4mLaJUH15Tn9Ii292SN+A1b1CKtY36TnjNnbZR7/L4TdUY21YKcZQ+hNd8SCnONupm8JpmtlE3oLQFXrPFdsF7//+pcX2/bQ65C16zzzaH3EPpO49XFdxjW6m+AF6zgNIdtqO9Lrymrm22GO3xO59tbZe7aptP/0Z4y0Y/hVwTlfa6HSN57FBpr0Pv4K/hLV9TauRwsDc8FV6SOtzhGLQZFZ3hJZ3pNE1aR6kpvKQppUFORwX+0/CS0/7gudP/AtRWD4XPH31GAAAAAElFTkSuQmCC"/>
</defs>
</svg>
This diff is collapsed.
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class OpenAIProvider(ModelProvider):
    # NOTE(review): this class is named OpenAIProvider but ships with the
    # Ollama provider commit — presumably a copy-paste from the OpenAI
    # provider module; confirm the expected class name against the provider
    # loader before renaming.

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in
            `provider_credential_schema`.
        """
        # Intentionally a no-op: this provider only exposes customizable
        # models, so credentials are validated per-model instead of at the
        # provider level.
        pass
provider: ollama
label:
en_US: Ollama
icon_large:
en_US: icon_l_en.svg
icon_small:
en_US: icon_s_en.svg
background: "#F9FAFB"
help:
title:
en_US: How to integrate with Ollama
zh_Hans: 如何集成 Ollama
url:
en_US: https://docs.dify.ai/advanced/model-configuration/ollama
supported_model_types:
- llm
- text-embedding
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: base_url
label:
zh_Hans: 基础 URL
en_US: Base URL
type: text-input
required: true
placeholder:
zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: 模型类型
en_US: Completion mode
type: select
required: true
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: '4096'
type: text-input
required: true
- variable: vision_support
label:
zh_Hans: 是否支持 Vision
en_US: Vision support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: "Yes"
zh_Hans: 是
- value: 'false'
label:
en_US: "No"
zh_Hans: 否
import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
import numpy as np
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
logger = logging.getLogger(__name__)
class OllamaEmbeddingModel(TextEmbeddingModel):
    """
    Model class for an Ollama text embedding model.

    Calls the Ollama ``/api/embeddings`` endpoint once per input text and
    aggregates the responses into a single :class:`TextEmbeddingResult`.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials; must contain ``base_url``
        :param texts: texts to embed
        :param user: unique user id (not forwarded to the Ollama API)
        :return: embeddings result
        """
        # Prepare headers and endpoint for the request
        headers = {
            'Content-Type': 'application/json'
        }

        endpoint_url = credentials.get('base_url')
        if not endpoint_url.endswith('/'):
            # urljoin discards the last path segment unless the base ends
            # with '/', so normalize before joining.
            endpoint_url += '/'

        endpoint_url = urljoin(endpoint_url, 'api/embeddings')

        # get model properties
        context_size = self._get_context_size(model, credentials)

        inputs = []
        used_tokens = 0

        for text in texts:
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # Keep a prefix proportional to how much of the text fits.
                # BUGFIX: the previous form int(len(text) * np.floor(ratio))
                # floored the ratio (which is < 1 whenever the text overflows)
                # to 0, truncating the entire text to an empty string.
                cutoff = int(len(text) * context_size / num_tokens)
                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)

        batched_embeddings = []

        for text in inputs:
            # Prepare the payload for the request
            payload = {
                'prompt': text,
                'model': model,
            }

            # Make the request to the Ollama embeddings API
            response = requests.post(
                endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300)
            )

            response.raise_for_status()  # Raise an exception for HTTP errors
            response_data = response.json()

            # Ollama returns a single vector per request under 'embedding'
            embeddings = response_data['embedding']
            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])

            used_tokens += embedding_used_tokens
            batched_embeddings.append(embeddings)

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Approximate number of tokens for given messages using GPT2 tokenizer

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: approximate total token count across all texts
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by issuing a minimal embedding request.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the probe request fails
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials

        :param model: model name
        :param credentials: model credentials; ``context_size`` is required,
            price fields are optional and default to zero-cost USD
        :return: the assembled model entity
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                # Ollama embeds one prompt per request, so one chunk at a time.
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # self.started_at is presumably set by the base class before
            # _invoke runs — TODO confirm against TextEmbeddingModel.
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeAuthorizationError: [
                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
            ],
            InvokeBadRequestError: [
                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
            ],
            InvokeRateLimitError: [
                requests.exceptions.RetryError  # Too many requests sent in a short period of time
            ],
            InvokeServerUnavailableError: [
                # HTTPError removed here: it was listed twice and only the
                # first mapping (InvokeBadRequestError) could ever match.
                requests.exceptions.ConnectionError,  # Engine Overloaded
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
        }
......@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
......
......@@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Base URL, eg. https://api.openai.com/v1
zh_Hans: Base URL, e.g. https://api.openai.com/v1
en_US: Base URL, e.g. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type
......
......@@ -33,5 +33,5 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000
......@@ -34,8 +34,8 @@ model_credential_schema:
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx
zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
- variable: model_uid
label:
zh_Hans: 模型UID
......
......@@ -121,6 +121,7 @@ class PromptTransform:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config
......@@ -343,7 +344,14 @@ class PromptTransform:
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)
return [prompt_message]
......@@ -434,6 +442,7 @@ class PromptTransform:
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: List[FileObj],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) -> List[PromptMessage]:
......@@ -461,7 +470,14 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(UserPromptMessage(content=prompt))
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=prompt))
return prompt_messages
......
......@@ -62,5 +62,8 @@ COHERE_API_KEY=
# Jina Credentials
JINA_API_KEY=
# Ollama Credentials
OLLAMA_BASE_URL=
# Mock Switch
MOCK_SWITCH=false
\ No newline at end of file
This diff is collapsed.
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
def test_validate_credentials():
    """An unreachable server must fail validation; a real one must pass."""
    embedding_model = OllamaEmbeddingModel()

    unreachable_credentials = {
        'base_url': 'http://localhost:21434',
        'mode': 'chat',
        'context_size': 4096,
    }
    with pytest.raises(CredentialsValidateFailedError):
        embedding_model.validate_credentials(
            model='mistral:text',
            credentials=unreachable_credentials,
        )

    live_credentials = {
        'base_url': os.environ.get('OLLAMA_BASE_URL'),
        'mode': 'chat',
        'context_size': 4096,
    }
    embedding_model.validate_credentials(
        model='mistral:text',
        credentials=live_credentials,
    )
def test_invoke_model():
    """Embedding two short texts yields two vectors and a 2-token usage."""
    embedding_model = OllamaEmbeddingModel()

    outcome = embedding_model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=["hello", "world"],
        user="abc-123"
    )

    assert isinstance(outcome, TextEmbeddingResult)
    assert len(outcome.embeddings) == 2
    assert outcome.usage.total_tokens == 2
def test_get_num_tokens():
    """Two single-token words should count as exactly two tokens."""
    embedding_model = OllamaEmbeddingModel()

    token_count = embedding_model.get_num_tokens(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=["hello", "world"]
    )

    assert token_count == 2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment