ai-tech / dify · Commits

Commit cca9edc9 (Unverified), authored Jan 12, 2024 by takatost, committed by GitHub on Jan 12, 2024
feat: ollama support (#2003)
parent 5e75f702

Showing 21 changed files with 1369 additions and 13 deletions (+1369, -13)
api/core/app_runner/generate_task_pipeline.py  (+26, -3)
api/core/model_runtime/model_providers/_position.yaml  (+1, -0)
api/core/model_runtime/model_providers/localai/localai.yaml  (+2, -2)
api/core/model_runtime/model_providers/ollama/__init__.py  (+0, -0)
api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg  (+15, -0)
api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg  (+15, -0)
api/core/model_runtime/model_providers/ollama/llm/__init__.py  (+0, -0)
api/core/model_runtime/model_providers/ollama/llm/llm.py  (+615, -0)
api/core/model_runtime/model_providers/ollama/ollama.py  (+17, -0)
api/core/model_runtime/model_providers/ollama/ollama.yaml  (+98, -0)
api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py  (+0, -0)
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py  (+221, -0)
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py  (+1, -0)
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml  (+2, -2)
api/core/model_runtime/model_providers/openllm/openllm.yaml  (+2, -2)
api/core/model_runtime/model_providers/xinference/xinference.yaml  (+2, -2)
api/core/prompt/prompt_transform.py  (+18, -2)
api/tests/integration_tests/.env.example  (+3, -0)
api/tests/integration_tests/model_runtime/ollama/__init__.py  (+0, -0)
api/tests/integration_tests/model_runtime/ollama/test_llm.py  (+260, -0)
api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py  (+71, -0)
api/core/app_runner/generate_task_pipeline.py (view file @ cca9edc9)

@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
                     "files": files
                 })
             else:
-                prompts.append({
-                    "role": 'user',
-                    "text": prompt_messages[0].content
-                })
+                prompt_message = prompt_messages[0]
+                text = ''
+                files = []
+                if isinstance(prompt_message.content, list):
+                    for content in prompt_message.content:
+                        if content.type == PromptMessageContentType.TEXT:
+                            content = cast(TextPromptMessageContent, content)
+                            text += content.data
+                        else:
+                            content = cast(ImagePromptMessageContent, content)
+                            files.append({
+                                "type": 'image',
+                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                                "detail": content.detail.value
+                            })
+                else:
+                    text = prompt_message.content
+
+                params = {
+                    "role": 'user',
+                    "text": text,
+                }
+
+                if files:
+                    params['files'] = files
+
+                prompts.append(params)
 
         return prompts
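The detail worth noting in this hunk is how image payloads are kept out of the stored prompt log: only the first and last ten characters of the (typically base64) image data survive, with a marker in between. A standalone sketch of that truncation, using a made-up payload:

# Hypothetical payload; mirrors the truncation expression in the hunk above.
data = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg"
print(data[:10] + '...[TRUNCATED]...' + data[-10:])
# -> data:image...[TRUNCATED]...AAANSUhEUg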
api/core/model_runtime/model_providers/_position.yaml (view file @ cca9edc9)

@@ -6,6 +6,7 @@
 - huggingface_hub
 - cohere
 - togetherai
+- ollama
 - zhipuai
 - baichuan
 - spark
api/core/model_runtime/model_providers/localai/localai.yaml (view file @ cca9edc9)

@@ -54,5 +54,5 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your LocalAI, for example https://example.com/xxx
+        zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
+        en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
api/core/model_runtime/model_providers/ollama/__init__.py (new file, mode 100644, empty)
api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg (new file, mode 100644; diff collapsed)
api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg (new file, mode 100644)

<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
  <g clip-path="url(#clip0_16325_59237)">
    <rect width="24" height="24" rx="5" fill="white"/>
    <rect x="3.5" width="17" height="24" fill="url(#pattern0)"/>
  </g>
  <defs>
    <pattern id="pattern0" patternContentUnits="objectBoundingBox" width="1" height="1">
      <use xlink:href="#image0_16325_59237" transform="matrix(0.00552486 0 0 0.00391344 0 -0.00092081)"/>
    </pattern>
    <clipPath id="clip0_16325_59237">
      <rect width="24" height="24" fill="white"/>
    </clipPath>
    <image id="image0_16325_59237" width="181" height="256" xlink:href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAALUAAAEACAMAAADC/cfpAA...[base64 PNG payload truncated]"/>
  </defs>
</svg>
api/core/model_runtime/model_providers/ollama/llm/__init__.py (new file, mode 100644, empty)
api/core/model_runtime/model_providers/ollama/llm/llm.py (new file, mode 100644; diff collapsed)
api/core/model_runtime/model_providers/ollama/ollama.py (new file, mode 100644)

import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OpenAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
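Note that validate_provider_credentials is a deliberate no-op: Ollama is registered with configurate_methods: customizable-model only (see ollama.yaml below), so there are no provider-wide credentials to check; validation happens per configured model instead. A minimal sketch of where that per-model check actually runs, mirroring the integration tests at the end of this commit (the base_url value is an assumption; 11434 is Ollama's default port):

# Sketch only: exercises the per-model validation path.
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel

OllamaEmbeddingModel().validate_credentials(
    model='mistral:text',
    credentials={
        'base_url': 'http://localhost:11434',  # assumed local Ollama server
        'mode': 'chat',
        'context_size': 4096,
    }
)  # raises CredentialsValidateFailedError if the server or model is unreachable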
api/core/model_runtime/model_providers/ollama/ollama.yaml (new file, mode 100644)

provider: ollama
label:
  en_US: Ollama
icon_large:
  en_US: icon_l_en.svg
icon_small:
  en_US: icon_s_en.svg
background: "#F9FAFB"
help:
  title:
    en_US: How to integrate with Ollama
    zh_Hans: 如何集成 Ollama
  url:
    en_US: https://docs.dify.ai/advanced/model-configuration/ollama
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: base_url
      label:
        zh_Hans: 基础 URL
        en_US: Base URL
      type: text-input
      required: true
      placeholder:
        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        zh_Hans: 模型类型
        en_US: Completion mode
      type: select
      required: true
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '4096'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: '4096'
      type: text-input
      required: true
    - variable: vision_support
      label:
        zh_Hans: 是否支持 Vision
        en_US: Vision support
      show_on:
        - variable: __model_type
          value: llm
      default: 'false'
      type: radio
      required: false
      options:
        - value: 'true'
          label:
            en_US: "Yes"
            zh_Hans: 是
        - value: 'false'
          label:
            en_US: "No"
            zh_Hans: 否
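For orientation, this schema drives the model-setup form in the UI; a saved LLM entry would reach the runtime as a credentials dict shaped roughly like the following (values are illustrative; the text-input fields presumably arrive as strings, matching the quoted defaults above):

# Hypothetical credentials dict produced by the form defined above.
credentials = {
    'base_url': 'http://192.168.1.100:11434',  # Ollama server base URL
    'mode': 'chat',                            # 'completion' or 'chat'
    'context_size': '4096',                    # text-input, so a string
    'max_tokens': '4096',                      # upper bound for max tokens
    'vision_support': 'false',                 # radio: 'true' / 'false'
}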
api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py (new file, mode 100644, empty)
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py (new file, mode 100644)

import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin

import requests
import json
import numpy as np

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
    PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

logger = logging.getLogger(__name__)


class OllamaEmbeddingModel(TextEmbeddingModel):
    """
    Model class for an Ollama text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # Prepare headers and payload for the request
        headers = {
            'Content-Type': 'application/json'
        }

        endpoint_url = credentials.get('base_url')
        if not endpoint_url.endswith('/'):
            endpoint_url += '/'

        endpoint_url = urljoin(endpoint_url, 'api/embeddings')

        # get model properties
        context_size = self._get_context_size(model, credentials)

        inputs = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)

        batched_embeddings = []

        for text in inputs:
            # Prepare the payload for the request
            payload = {
                'prompt': text,
                'model': model,
            }

            # Make the request to the Ollama embeddings API
            response = requests.post(
                endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300)
            )

            response.raise_for_status()  # Raise an exception for HTTP errors
            response_data = response.json()

            # Extract embeddings and used tokens from the response
            embeddings = response_data['embedding']
            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])

            used_tokens += embedding_used_tokens
            batched_embeddings.append(embeddings)

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Approximate number of tokens for given messages using GPT2 tokenizer

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeAuthorizationError: [
                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
            ],
            InvokeBadRequestError: [
                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
            ],
            InvokeRateLimitError: [
                requests.exceptions.RetryError  # Too many requests sent in a short period of time
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
        }
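Stripped of the runtime plumbing, _invoke issues one POST per input text to Ollama's /api/embeddings endpoint. A minimal standalone sketch of that request, assuming an Ollama server on the default port with the model already pulled:

import requests

# One embedding request, matching the payload _invoke builds per text.
resp = requests.post(
    'http://localhost:11434/api/embeddings',  # urljoin(base_url, 'api/embeddings')
    headers={'Content-Type': 'application/json'},
    json={'model': 'mistral:text', 'prompt': 'hello'},
    timeout=(10, 300),  # (connect, read) timeouts, as above
)
resp.raise_for_status()
print(len(resp.json()['embedding']))  # dimensionality of the returned vector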
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py (view file @ cca9edc9)

@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     message=AssistantPromptMessage(content=""),
                     finish_reason="Non-JSON encountered."
                 )
+                break
 
             if not chunk_json or len(chunk_json['choices']) == 0:
                 continue
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml (view file @ cca9edc9)

@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: Base URL, eg. https://api.openai.com/v1
-        en_US: Base URL, eg. https://api.openai.com/v1
+        zh_Hans: Base URL, e.g. https://api.openai.com/v1
+        en_US: Base URL, e.g. https://api.openai.com/v1
   - variable: mode
     show_on:
       - variable: __model_type
api/core/model_runtime/model_providers/openllm/openllm.yaml (view file @ cca9edc9)

@@ -33,5 +33,5 @@ model_credential_schema:
       type: text-input
       required: true
      placeholder:
-        zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
+        zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
+        en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000
api/core/model_runtime/model_providers/xinference/xinference.yaml (view file @ cca9edc9)

@@ -34,8 +34,8 @@ model_credential_schema:
       type: secret-input
       required: true
       placeholder:
-        zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your Xinference, for example https://example.com/xxx
+        zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+        en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
   - variable: model_uid
     label:
       zh_Hans: 模型UID
api/core/prompt/prompt_transform.py (view file @ cca9edc9)

@@ -121,6 +121,7 @@ class PromptTransform:
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
             query=query,
+            files=files,
             context=context,
             memory=memory,
             model_config=model_config

@@ -343,7 +344,14 @@ class PromptTransform:
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:
-            prompt_message = UserPromptMessage(content=prompt)
+            if files:
+                prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_message = UserPromptMessage(content=prompt_message_contents)
+            else:
+                prompt_message = UserPromptMessage(content=prompt)
 
         return [prompt_message]

@@ -434,6 +442,7 @@ class PromptTransform:
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict,
                                      query: str,
+                                     files: List[FileObj],
                                      context: Optional[str],
                                      memory: Optional[TokenBufferMemory],
                                      model_config: ModelConfigEntity) -> List[PromptMessage]:

@@ -461,7 +470,14 @@ class PromptTransform:
             prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-            prompt_messages.append(UserPromptMessage(content=prompt))
+            if files:
+                prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+            else:
+                prompt_messages.append(UserPromptMessage(content=prompt))
 
         return prompt_messages
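The two larger hunks apply the same files-aware pattern: a plain string prompt is lifted into a content list holding one text part plus one part per attached file. A schematic sketch of that shape (dicts stand in for dify's TextPromptMessageContent and file.prompt_message_content types):

# Schematic only; real code uses dify's message content types.
def to_user_content(prompt, files):
    if not files:
        return prompt  # plain text content, unchanged behaviour
    contents = [{'type': 'text', 'data': prompt}]
    for file in files:
        contents.append(file)  # stands in for file.prompt_message_content
    return contents  # multimodal content list for a UserPromptMessage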
api/tests/integration_tests/.env.example (view file @ cca9edc9)

@@ -62,5 +62,8 @@ COHERE_API_KEY=
 # Jina Credentials
 JINA_API_KEY=
 
+# Ollama Credentials
+OLLAMA_BASE_URL=
+
 # Mock Switch
 MOCK_SWITCH=false
\ No newline at end of file
api/tests/integration_tests/model_runtime/ollama/__init__.py (new file, mode 100644, empty)
api/tests/integration_tests/model_runtime/ollama/test_llm.py (new file, mode 100644; diff collapsed)
api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py (new file, mode 100644)

import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel


def test_validate_credentials():
    model = OllamaEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistral:text',
            credentials={
                'base_url': 'http://localhost:21434',
                'mode': 'chat',
                'context_size': 4096,
            }
        )

    model.validate_credentials(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        }
    )


def test_invoke_model():
    model = OllamaEmbeddingModel()

    result = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2


def test_get_num_tokens():
    model = OllamaEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=[
            "hello",
            "world"
        ]
    )

    assert num_tokens == 2
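To run these against a live server, set OLLAMA_BASE_URL (e.g. http://localhost:11434, per the .env.example change above) and point pytest at api/tests/integration_tests/model_runtime/ollama/. The negative validate_credentials case presumably relies on nothing listening on localhost:21434, so validation fails as expected; note also that the token assertions (total_tokens == 2 for "hello" and "world") reflect the GPT-2 approximation in get_num_tokens, not Ollama's own token accounting.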