ai-tech / dify · Commits

Commit cca9edc9 (Unverified)
Authored Jan 12, 2024 by takatost; committed by GitHub on Jan 12, 2024

feat: ollama support (#2003)

Parent: 5e75f702

Showing 21 changed files with 1369 additions and 13 deletions (+1369, -13)
Changed files:

api/core/app_runner/generate_task_pipeline.py                       +26   -3
api/core/model_runtime/model_providers/_position.yaml                +1   -0
api/core/model_runtime/model_providers/localai/localai.yaml          +2   -2
api/core/model_runtime/model_providers/ollama/__init__.py            +0   -0
...odel_runtime/model_providers/ollama/_assets/icon_l_en.svg        +15   -0
...odel_runtime/model_providers/ollama/_assets/icon_s_en.svg        +15   -0
...core/model_runtime/model_providers/ollama/llm/__init__.py         +0   -0
api/core/model_runtime/model_providers/ollama/llm/llm.py           +615   -0
api/core/model_runtime/model_providers/ollama/ollama.py             +17   -0
api/core/model_runtime/model_providers/ollama/ollama.yaml           +98   -0
...runtime/model_providers/ollama/text_embedding/__init__.py         +0   -0
...e/model_providers/ollama/text_embedding/text_embedding.py       +221   -0
..._runtime/model_providers/openai_api_compatible/llm/llm.py         +1   -0
...roviders/openai_api_compatible/openai_api_compatible.yaml         +2   -2
api/core/model_runtime/model_providers/openllm/openllm.yaml          +2   -2
.../model_runtime/model_providers/xinference/xinference.yaml         +2   -2
api/core/prompt/prompt_transform.py                                 +18   -2
api/tests/integration_tests/.env.example                             +3   -0
api/tests/integration_tests/model_runtime/ollama/__init__.py         +0   -0
api/tests/integration_tests/model_runtime/ollama/test_llm.py       +260   -0
...gration_tests/model_runtime/ollama/test_text_embedding.py        +71   -0
api/core/app_runner/generate_task_pipeline.py  View file @ cca9edc9
...
@@ -459,10 +459,33 @@ class GenerateTaskPipeline:
                         "files": files
                     })
                 else:
-                    prompts.append({
+                    prompt_message = prompt_messages[0]
+                    text = ''
+                    files = []
+                    if isinstance(prompt_message.content, list):
+                        for content in prompt_message.content:
+                            if content.type == PromptMessageContentType.TEXT:
+                                content = cast(TextPromptMessageContent, content)
+                                text += content.data
+                            else:
+                                content = cast(ImagePromptMessageContent, content)
+                                files.append({
+                                    "type": 'image',
+                                    "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                                    "detail": content.detail.value
+                                })
+                    else:
+                        text = prompt_message.content
+
+                    params = {
                         "role": 'user',
-                        "text": prompt_messages[0].content
-                    })
+                        "text": text,
+                    }
+
+                    if files:
+                        params['files'] = files
+
+                    prompts.append(params)
 
         return prompts
...
...
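A side note on the new branch above: inline base64 image data can be very large, so only a short preview of each image is kept in the stored prompt's "data" field. The truncation expression, run on a made-up base64 string:

    data = 'iVBORw0KGgoAAAANSUhEUgAAALUAAAEACAMAAA'  # hypothetical base64 payload
    preview = data[:10] + '...[TRUNCATED]...' + data[-10:]
    # result: 'iVBORw0KGg...[TRUNCATED]...AAEACAMAAA'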
api/core/model_runtime/model_providers/_position.yaml  View file @ cca9edc9
...
@@ -6,6 +6,7 @@
 - huggingface_hub
 - cohere
 - togetherai
+- ollama
 - zhipuai
 - baichuan
 - spark
...
...
api/core/model_runtime/model_providers/localai/localai.yaml  View file @ cca9edc9
...
@@ -54,5 +54,5 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: 在此输入LocalAI的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your LocalAI, for example https://example.com/xxx
+        zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
+        en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
api/core/model_runtime/model_providers/ollama/__init__.py  0 → 100644  View file @ cca9edc9  (new empty file)
api/core/model_runtime/model_providers/ollama/_assets/icon_l_en.svg  0 → 100644  View file @ cca9edc9
[SVG asset: 82×24 Ollama wordmark; embeds a 181×256 base64-encoded PNG (data omitted)]
api/core/model_runtime/model_providers/ollama/_assets/icon_s_en.svg  0 → 100644  View file @ cca9edc9
[SVG asset: 24×24 rounded icon; embeds the same 181×256 base64-encoded PNG (data omitted)]
api/core/model_runtime/model_providers/ollama/llm/__init__.py  0 → 100644  View file @ cca9edc9  (new empty file)
api/core/model_runtime/model_providers/ollama/llm/llm.py  0 → 100644  View file @ cca9edc9

import json
import logging
import re
from decimal import Decimal
from typing import Optional, Generator, Union, List, cast
from urllib.parse import urljoin

import requests

from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage, AssistantPromptMessage, \
    UserPromptMessage, PromptMessageContentType, ImagePromptMessageContent, \
    TextPromptMessageContent, SystemPromptMessage
from core.model_runtime.entities.model_entities import I18nObject, ModelType, \
    PriceConfig, AIModelEntity, FetchFrom, ModelPropertyKey, ParameterRule, ParameterType, DefaultParameterName, \
    ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, \
    LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class OllamaLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Ollama large language model.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        return self._generate(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            stop=stop,
            stream=stream,
            user=user
        )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if model_mode == LLMMode.CHAT:
            # chat model
            return self._num_tokens_from_messages(prompt_messages)
        else:
            first_prompt_message = prompt_messages[0]
            if isinstance(first_prompt_message.content, str):
                text = first_prompt_message.content
            else:
                text = ''
                for message_content in first_prompt_message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        text = message_content.data
                        break
            return self._get_num_tokens_by_gpt2(text)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._generate(
                model=model,
                credentials=credentials,
                prompt_messages=[UserPromptMessage(content="ping")],
                model_parameters={
                    'num_predict': 5
                },
                stream=False
            )
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  stop: Optional[List[str]] = None, stream: bool = True,
                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        headers = {
            'Content-Type': 'application/json'
        }

        endpoint_url = credentials['base_url']
        if not endpoint_url.endswith('/'):
            endpoint_url += '/'

        # prepare the request payload
        data = {
            'model': model,
            'stream': stream
        }

        if 'format' in model_parameters:
            data['format'] = model_parameters['format']
            del model_parameters['format']

        data['options'] = model_parameters or {}

        if stop:
            data['stop'] = "\n".join(stop)

        completion_type = LLMMode.value_of(credentials['mode'])

        if completion_type is LLMMode.CHAT:
            endpoint_url = urljoin(endpoint_url, 'api/chat')
            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
        else:
            endpoint_url = urljoin(endpoint_url, 'api/generate')
            first_prompt_message = prompt_messages[0]
            if isinstance(first_prompt_message, UserPromptMessage):
                first_prompt_message = cast(UserPromptMessage, first_prompt_message)
                if isinstance(first_prompt_message.content, str):
                    data['prompt'] = first_prompt_message.content
                else:
                    text = ''
                    images = []
                    for message_content in first_prompt_message.content:
                        if message_content.type == PromptMessageContentType.TEXT:
                            message_content = cast(TextPromptMessageContent, message_content)
                            text = message_content.data
                        elif message_content.type == PromptMessageContentType.IMAGE:
                            message_content = cast(ImagePromptMessageContent, message_content)
                            image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                            images.append(image_data)

                    data['prompt'] = text
                    data['images'] = images

        # send the POST request (streamed when requested)
        response = requests.post(
            endpoint_url,
            headers=headers,
            json=data,
            timeout=(10, 60),
            stream=stream
        )

        response.encoding = "utf-8"
        if response.status_code != 200:
            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")

        if stream:
            return self._handle_generate_stream_response(model, credentials, completion_type, response,
                                                         prompt_messages)

        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                  response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm completion response

        :param model: model name
        :param credentials: model credentials
        :param completion_type: completion type
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm result
        """
        response_json = response.json()

        if completion_type is LLMMode.CHAT:
            message = response_json.get('message', {})
            response_content = message.get('content', '')
        else:
            response_content = response_json['response']

        assistant_message = AssistantPromptMessage(content=response_content)

        if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
            # transform usage
            prompt_tokens = response_json["prompt_eval_count"]
            completion_tokens = response_json["eval_count"]
        else:
            # calculate num tokens
            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
            completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=response_json["model"],
            prompt_messages=prompt_messages,
            message=assistant_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                         response: requests.Response,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm completion stream response

        :param model: model name
        :param credentials: model credentials
        :param completion_type: completion type
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        full_text = ''
        chunk_index = 0

        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
                -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
            completion_tokens = self._get_num_tokens_by_gpt2(full_text)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
            if not chunk:
                continue

            try:
                chunk_json = json.loads(chunk)
                # stream ended
            except json.JSONDecodeError as e:
                yield create_final_llm_result_chunk(
                    index=chunk_index,
                    message=AssistantPromptMessage(content=""),
                    finish_reason="Non-JSON encountered."
                )

                chunk_index += 1
                break

            if completion_type is LLMMode.CHAT:
                if not chunk_json:
                    continue

                if 'message' not in chunk_json:
                    text = ''
                else:
                    text = chunk_json.get('message').get('content', '')
            else:
                if not chunk_json:
                    continue

                # transform assistant message to prompt message
                text = chunk_json['response']

            assistant_prompt_message = AssistantPromptMessage(content=text)

            full_text += text

            if chunk_json['done']:
                # calculate num tokens
                if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
                    # transform usage
                    prompt_tokens = chunk_json["prompt_eval_count"]
                    completion_tokens = chunk_json["eval_count"]
                else:
                    # calculate num tokens
                    prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
                    completion_tokens = self._get_num_tokens_by_gpt2(full_text)

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=chunk_json['model'],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                        finish_reason='stop',
                        usage=usage
                    )
                )
            else:
                yield LLMResultChunk(
                    model=chunk_json['model'],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                )

            chunk_index += 1

    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Ollama API
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                text = ''
                images = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        text = message_content.data
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        images.append(image_data)

                message_dict = {"role": "user", "content": text, "images": images}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
        """
        Calculate num tokens.

        :param messages: messages
        """
        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            for key, value in message.items():
                num_tokens += self._get_num_tokens_by_gpt2(str(key))
                num_tokens += self._get_num_tokens_by_gpt2(str(value))

        return num_tokens

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Get customizable model schema.

        :param model: model name
        :param credentials: credentials
        :return: model schema
        """
        extras = {}

        if 'vision_support' in credentials and credentials['vision_support'] == 'true':
            extras['features'] = [ModelFeature.VISION]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                zh_Hans=model,
                en_US=model
            ),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: credentials.get('mode'),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
            },
            parameter_rules=[
                ParameterRule(
                    name=DefaultParameterName.TEMPERATURE.value,
                    use_template=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="The temperature of the model. "
                                          "Increasing the temperature will make the model answer "
                                          "more creatively. (Default: 0.8)"),
                    default=0.8,
                    min=0,
                    max=2
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    use_template=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                                          "more diverse text, while a lower value (e.g., 0.5) will generate more "
                                          "focused and conservative text. (Default: 0.9)"),
                    default=0.9,
                    min=0,
                    max=1
                ),
                ParameterRule(
                    name="top_k",
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Reduces the probability of generating nonsense. "
                                          "A higher value (e.g. 100) will give more diverse answers, "
                                          "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
                    default=40,
                    min=1,
                    max=100
                ),
                ParameterRule(
                    name='repeat_penalty',
                    label=I18nObject(en_US="Repeat Penalty"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
                                          "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                                          "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
                    default=1.1,
                    min=-2,
                    max=2
                ),
                ParameterRule(
                    name='num_predict',
                    use_template='max_tokens',
                    label=I18nObject(en_US="Num Predict"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
                                          "(Default: 128, -1 = infinite generation, -2 = fill context)"),
                    default=128,
                    min=-2,
                    max=int(credentials.get('max_tokens', 4096)),
                ),
                ParameterRule(
                    name='mirostat',
                    label=I18nObject(en_US="Mirostat sampling"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
                                          "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
                    default=0,
                    min=0,
                    max=2
                ),
                ParameterRule(
                    name='mirostat_eta',
                    label=I18nObject(en_US="Mirostat Eta"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
                                          "the generated text. A lower learning rate will result in slower adjustments, "
                                          "while a higher learning rate will make the algorithm more responsive. "
                                          "(Default: 0.1)"),
                    default=0.1,
                    precision=1
                ),
                ParameterRule(
                    name='mirostat_tau',
                    label=I18nObject(en_US="Mirostat Tau"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
                                          "A lower value will result in more focused and coherent text. (Default: 5.0)"),
                    default=5.0,
                    precision=1
                ),
                ParameterRule(
                    name='num_ctx',
                    label=I18nObject(en_US="Size of context window"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
                                          "(Default: 2048)"),
                    default=2048,
                    min=1
                ),
                ParameterRule(
                    name='num_gpu',
                    label=I18nObject(en_US="Num GPU"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="The number of layers to send to the GPU(s). "
                                          "On macOS it defaults to 1 to enable metal support, 0 to disable."),
                    default=1,
                    min=0,
                    max=1
                ),
                ParameterRule(
                    name='num_thread',
                    label=I18nObject(en_US="Num Thread"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the number of threads to use during computation. "
                                          "By default, Ollama will detect this for optimal performance. "
                                          "It is recommended to set this value to the number of physical CPU cores "
                                          "your system has (as opposed to the logical number of cores)."),
                    min=1,
                ),
                ParameterRule(
                    name='repeat_last_n',
                    label=I18nObject(en_US="Repeat last N"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
                                          "(Default: 64, 0 = disabled, -1 = num_ctx)"),
                    default=64,
                    min=-1
                ),
                ParameterRule(
                    name='tfs_z',
                    label=I18nObject(en_US="TFS Z"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                                          "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                                          "while a value of 1.0 disables this setting. (default: 1)"),
                    default=1,
                    precision=1
                ),
                ParameterRule(
                    name='seed',
                    label=I18nObject(en_US="Seed"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
                                          "a specific number will make the model generate the same text for "
                                          "the same prompt. (Default: 0)"),
                    default=0
                ),
                ParameterRule(
                    name='format',
                    label=I18nObject(en_US="Format"),
                    type=ParameterType.STRING,
                    help=I18nObject(en_US="the format to return a response in."
                                          " Currently the only accepted value is json."),
                    options=['json'],
                )
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                output=Decimal(credentials.get('output_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            ),
            **extras
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeAuthorizationError: [
                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
            ],
            InvokeBadRequestError: [
                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
            ],
            InvokeRateLimitError: [
                requests.exceptions.RetryError  # Too many requests sent in a short period of time
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
        }
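The two payloads that _generate builds map one-to-one onto Ollama's REST endpoints: completion mode posts to api/generate and reads the answer from "response", while chat mode posts to api/chat and reads "message"."content". A minimal standalone sketch, assuming an Ollama server on the default local port 11434 and a pulled mistral model (both assumptions, not part of this commit):

    import requests

    base_url = 'http://localhost:11434'  # assumed local default; maps to the 'base_url' credential

    # completion mode -> POST {base_url}/api/generate
    gen = requests.post(f'{base_url}/api/generate', json={
        'model': 'mistral',
        'prompt': 'Why is the sky blue?',
        'stream': False,
        'options': {'temperature': 0.8, 'num_predict': 128},
    }, timeout=(10, 60))
    print(gen.json()['response'])

    # chat mode -> POST {base_url}/api/chat
    chat = requests.post(f'{base_url}/api/chat', json={
        'model': 'mistral',
        'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
        'stream': False,
    }, timeout=(10, 60))
    print(chat.json()['message']['content'])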
api/core/model_runtime/model_providers/ollama/ollama.py  0 → 100644  View file @ cca9edc9

import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OpenAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
api/core/model_runtime/model_providers/ollama/ollama.yaml  0 → 100644  View file @ cca9edc9

provider: ollama
label:
  en_US: Ollama
icon_large:
  en_US: icon_l_en.svg
icon_small:
  en_US: icon_s_en.svg
background: "#F9FAFB"
help:
  title:
    en_US: How to integrate with Ollama
    zh_Hans: 如何集成 Ollama
  url:
    en_US: https://docs.dify.ai/advanced/model-configuration/ollama
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: base_url
      label:
        zh_Hans: 基础 URL
        en_US: Base URL
      type: text-input
      required: true
      placeholder:
        zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
        en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        zh_Hans: 模型类型
        en_US: Completion mode
      type: select
      required: true
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '4096'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: '4096'
      type: text-input
      required: true
    - variable: vision_support
      label:
        zh_Hans: 是否支持 Vision
        en_US: Vision support
      show_on:
        - variable: __model_type
          value: llm
      default: 'false'
      type: radio
      required: false
      options:
        - value: 'true'
          label:
            en_US: Yes
            zh_Hans: 是
        - value: 'false'
          label:
            en_US: No
            zh_Hans: 否
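The credential form above is what produces the plain credentials dict the model classes index into. A sketch of a configured LLM entry (example values; text-input fields arrive as strings, which is why llm.py calls int() on context_size and max_tokens):

    credentials = {
        'base_url': 'http://192.168.1.100:11434',  # example address
        'mode': 'chat',                # or 'completion'
        'context_size': '4096',
        'max_tokens': '4096',
        'vision_support': 'false',     # 'true' adds ModelFeature.VISION to the model schema
    }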
api/core/model_runtime/model_providers/ollama/text_embedding/__init__.py  0 → 100644  View file @ cca9edc9  (new empty file)
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py  0 → 100644  View file @ cca9edc9

import logging
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin

import requests
import json

import numpy as np

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import PriceType, ModelPropertyKey, ModelType, AIModelEntity, FetchFrom, \
    PriceConfig
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError, InvokeBadRequestError, \
    InvokeRateLimitError, InvokeServerUnavailableError, InvokeConnectionError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

logger = logging.getLogger(__name__)


class OllamaEmbeddingModel(TextEmbeddingModel):
    """
    Model class for an Ollama text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """

        # Prepare headers and payload for the request
        headers = {
            'Content-Type': 'application/json'
        }

        endpoint_url = credentials.get('base_url')
        if not endpoint_url.endswith('/'):
            endpoint_url += '/'

        endpoint_url = urljoin(endpoint_url, 'api/embeddings')

        # get model properties
        context_size = self._get_context_size(model, credentials)

        inputs = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)

        batched_embeddings = []

        for text in inputs:
            # Prepare the payload for the request
            payload = {
                'prompt': text,
                'model': model,
            }

            # Make the request to the Ollama embeddings API
            response = requests.post(
                endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300)
            )

            response.raise_for_status()  # Raise an exception for HTTP errors
            response_data = response.json()

            # Extract embeddings and used tokens from the response
            embeddings = response_data['embedding']
            embedding_used_tokens = self.get_num_tokens(model, credentials, [text])

            used_tokens += embedding_used_tokens
            batched_embeddings.append(embeddings)

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Approximate number of tokens for given messages using GPT2 tokenizer

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeAuthorizationError: [
                requests.exceptions.InvalidHeader,  # Missing or Invalid API Key
            ],
            InvokeBadRequestError: [
                requests.exceptions.HTTPError,  # Invalid Endpoint URL or model name
                requests.exceptions.InvalidURL,  # Misconfigured request or other API error
            ],
            InvokeRateLimitError: [
                requests.exceptions.RetryError  # Too many requests sent in a short period of time
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
        }
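For reference, the api/embeddings endpoint that _invoke loops over takes one prompt per request and returns a single vector. A standalone sketch, assuming a local server and an embedding-capable model such as nomic-embed-text (the model name is an example, not something this commit ships):

    import requests

    resp = requests.post('http://localhost:11434/api/embeddings', json={
        'model': 'nomic-embed-text',  # any embedding-capable Ollama model
        'prompt': 'hello world',
    }, timeout=(10, 300))
    vector = resp.json()['embedding']  # list[float]
    print(len(vector))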
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py  View file @ cca9edc9
...
@@ -360,6 +360,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     message=AssistantPromptMessage(content=""),
                     finish_reason="Non-JSON encountered."
                 )
+                break
 
             if not chunk_json or len(chunk_json['choices']) == 0:
                 continue
...
...
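The single added line is the break: once a non-JSON line is hit, the final chunk has already been yielded, so the loop should stop reading instead of falling through to the chunk_json checks with a stale or unset value. Schematically (illustrative only, not the module's code):

    import json

    for line in ('{"choices": [{"delta": {"content": "hi"}}]}', '[DONE]'):  # hypothetical stream tail
        try:
            chunk_json = json.loads(line)
        except json.JSONDecodeError:
            # emit the final chunk here, then stop consuming the stream
            break
        # ... process chunk_json ...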
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml  View file @ cca9edc9
...
@@ -33,8 +33,8 @@ model_credential_schema:
       type: text-input
       required: true
       placeholder:
-        zh_Hans: Base URL, eg. https://api.openai.com/v1
-        en_US: Base URL, eg. https://api.openai.com/v1
+        zh_Hans: Base URL, e.g. https://api.openai.com/v1
+        en_US: Base URL, e.g. https://api.openai.com/v1
     - variable: mode
       show_on:
         - variable: __model_type
...
...
api/core/model_runtime/model_providers/openllm/openllm.yaml  View file @ cca9edc9
...
@@ -33,5 +33,5 @@ model_credential_schema:
       type: text-input
       required: true
      placeholder:
-        zh_Hans: 在此输入OpenLLM的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your OpenLLM, for example https://example.com/xxx
+        zh_Hans: 在此输入OpenLLM的服务器地址,如 http://192.168.1.100:3000
+        en_US: Enter the url of your OpenLLM, e.g. http://192.168.1.100:3000
api/core/model_runtime/model_providers/xinference/xinference.yaml  View file @ cca9edc9
...
@@ -34,8 +34,8 @@ model_credential_schema:
       type: secret-input
       required: true
       placeholder:
-        zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
-        en_US: Enter the url of your Xinference, for example https://example.com/xxx
+        zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+        en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
     - variable: model_uid
       label:
         zh_Hans: 模型UID
...
...
api/core/prompt/prompt_transform.py  View file @ cca9edc9
...
@@ -121,6 +121,7 @@ class PromptTransform:
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=context,
                 memory=memory,
                 model_config=model_config
...
@@ -343,7 +344,14 @@
             prompt_message = UserPromptMessage(content=prompt_message_contents)
         else:
-            prompt_message = UserPromptMessage(content=prompt)
+            if files:
+                prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+                for file in files:
+                    prompt_message_contents.append(file.prompt_message_content)
+
+                prompt_message = UserPromptMessage(content=prompt_message_contents)
+            else:
+                prompt_message = UserPromptMessage(content=prompt)
 
         return [prompt_message]
...
@@ -434,6 +442,7 @@
                                            prompt_template_entity: PromptTemplateEntity,
                                            inputs: dict,
                                            query: str,
+                                           files: List[FileObj],
                                            context: Optional[str],
                                            memory: Optional[TokenBufferMemory],
                                            model_config: ModelConfigEntity) -> List[PromptMessage]:
...
@@ -461,7 +470,14 @@
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-        prompt_messages.append(UserPromptMessage(content=prompt))
+        if files:
+            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
+            for file in files:
+                prompt_message_contents.append(file.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=prompt))
 
         return prompt_messages
...
...
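These prompt_transform hunks are what route uploaded files into the new provider: when a UserPromptMessage with a mixed content list reaches OllamaLargeLanguageModel._convert_prompt_message_to_dict above, its image parts are stripped of their data:image/...;base64, prefix and sent in Ollama's images field. The resulting chat message ends up roughly like (sketch with example values):

    {
        'role': 'user',
        'content': 'What is in this picture?',  # text part
        'images': ['iVBORw0KGgo...'],           # raw base64, data-URI prefix removed
    }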
api/tests/integration_tests/.env.example  View file @ cca9edc9
...
@@ -62,5 +62,8 @@ COHERE_API_KEY=
 # Jina Credentials
 JINA_API_KEY=
 
+# Ollama Credentials
+OLLAMA_BASE_URL=
+
 # Mock Switch
 MOCK_SWITCH=false
\ No newline at end of file
api/tests/integration_tests/model_runtime/ollama/__init__.py
0 → 100644
View file @
cca9edc9
api/tests/integration_tests/model_runtime/ollama/test_llm.py
0 → 100644
View file @
cca9edc9
import os
from typing import Generator

import pytest

from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
    SystemPromptMessage, TextPromptMessageContent, ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
    LLMResultChunk
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel


def test_validate_credentials():
    model = OllamaLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistral:text',
            credentials={
                'base_url': 'http://localhost:21434',
                'mode': 'chat',
                'context_size': 2048,
                'max_tokens': 2048,
            }
        )

    model.validate_credentials(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 2048,
            'max_tokens': 2048,
        }
    )


def test_invoke_model():
    model = OllamaLargeLanguageModel()

    response = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            UserPromptMessage(content='Who are you?')
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            'num_predict': 10
        },
        stop=['How'],
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_invoke_stream_model():
    model = OllamaLargeLanguageModel()

    response = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(content='Who are you?')
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            'num_predict': 10
        },
        stop=['How'],
        stream=True
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)


def test_invoke_completion_model():
    model = OllamaLargeLanguageModel()

    response = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'completion',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            UserPromptMessage(content='Who are you?')
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            'num_predict': 10
        },
        stop=['How'],
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_invoke_stream_completion_model():
    model = OllamaLargeLanguageModel()

    response = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'completion',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(content='Who are you?')
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            'num_predict': 10
        },
        stop=['How'],
        stream=True
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)


def test_invoke_completion_model_with_vision():
    model = OllamaLargeLanguageModel()

    result = model.invoke(
        model='llava',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'completion',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            UserPromptMessage(
                content=[
                    TextPromptMessageContent(
                        data='What is this in this picture?',
                    ),
                    ImagePromptMessageContent(
                        data=
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+
G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuE
LekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC'
                    )
                ]
            )
        ],
        model_parameters={
            'temperature': 0.1,
            'num_predict': 100
        },
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


def test_invoke_chat_model_with_vision():
    model = OllamaLargeLanguageModel()

    result = model.invoke(
        model='llava',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            UserPromptMessage(
                content=[
                    TextPromptMessageContent(
                        data='What is this in this picture?',
                    ),
                    ImagePromptMessageContent(
                        data=
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAE4AAABMCAYAAADDYoEWAAAMQGlDQ1BJQ0MgUHJvZmlsZQAASImVVwdYU8kWnluSkEBoAQSkhN4EkRpASggt9I4gKiEJEEqMgaBiRxcVXLuIgA1dFVGwAmJBETuLYu+LBRVlXSzYlTcpoOu+8r35vrnz33/O/OfMmbllAFA7zhGJclF1APKEBeLYYH/6uOQUOukpIAEdoAy0gA2Hmy9iRkeHA1iG2r+Xd9cBIm2v2Eu1/tn/X4sGj5/PBQCJhjidl8/Ng/gAAHg1VyQuAIAo5c2mFoikGFagJYYBQrxIijPluFqK0+V4j8wmPpYFcTsASiocjjgTANVLkKcXcjOhhmo/xI5CnkAIgBodYp+8vMk8iNMgtoY2Ioil+oz0H3Qy/6aZPqzJ4WQOY/lcZEUpQJAvyuVM/z/T8b9LXq5kyIclrCpZ4pBY6Zxh3m7mTA6TYhWI+4TpkVEQa0L8QcCT2UOMUrIkIQlye9SAm8+COYMrDVBHHicgDGIDiIOEuZHhCj49QxDEhhjuEHSaoIAdD7EuxIv4+YFxCptN4smxCl9oY4aYxVTwZzlimV+pr/uSnASmQv91Fp+t0MdUi7LikyCmQGxeKEiMhFgVYof8nLgwhc3YoixW5JCNWBIrjd8c4li+MNhfro8VZoiDYhX2pXn5Q/PFNmUJ2JEKvK8gKz5Enh+sncuRxQ/ngl3iC5kJQzr8/HHhQ3Ph8QMC5XPHnvGFCXEKnQ+iAv9Y+VicIsqNVtjjpvzcYClvCrFLfmGcYiyeWAA3pFwfzxAVRMfL48SLsjmh0fJ48OUgHLBAAKADCazpYDLIBoLOvqY+eCfvCQIcIAaZgA/sFczQiCRZjxBe40AR+BMiPsgfHucv6+WDQsh/HWblV3uQIestlI3IAU8gzgNhIBfeS2SjhMPeEsFjyAj+4Z0DKxfGmwurtP/f80Psd4YJmXAFIxnySFcbsiQGEgOIIcQgog2uj/vgXng4vPrB6oQzcI+heXy3JzwhdBEeEq4Rugm3JgmKxT9FGQG6oX6QIhfpP+YCt4Sarrg/7g3VoTKug+sDe9wF+mHivtCzK2RZirilWaH/pP23GfywGgo7siMZJY8g+5Gtfx6paqvqOqwizfWP+ZHHmj6cb9Zwz8/+WT9knwfbsJ8tsUXYfuwMdgI7hx3BmgAda8WasQ7sqBQP767Hst015C1WFk8O1BH8w9/Qykozme9Y59jr+EXeV8CfJn1HA9Zk0XSxIDOrgM6EXwQ+nS3kOoyiOzk6OQMg/b7IX19vYmTfDUSn4zs3/w8AvFsHBwcPf+dCWwHY6w4f/0PfOWsG/HQoA3D2EFciLpRzuPRCgG8JNfik6QEjYAas4XycgBvwAn4gEISCKBAPksFEGH0W3OdiMBXMBPNACSgDy8EaUAk2gi1gB9gN9oEmcAScAKfBBXAJXAN34O7pAS9AP3gHPiMIQkKoCA3RQ4wRC8QOcUIYiA8SiIQjsUgykoZkIkJEgsxE5iNlyEqkEtmM1CJ7kUPICeQc0oXcQh4gvchr5BOKoSqoFmqIWqKjUQbKRMPQeHQCmolOQYvQBehStAKtQXehjegJ9AJ6De1GX6ADGMCUMR3MBLPHGBgLi8JSsAxMjM3GSrFyrAarx1rgOl/BurE+7CNOxGk4HbeHOzgET8C5+BR8Nr4Er8R34I14O34Ff4D3498IVIIBwY7gSWATxhEyCVMJJYRywjbCQcIp+Cz1EN4RiUQdohXRHT6LycRs4gziEuJ6YgPxOLGL+Ig4QCKR9Eh2JG9SFIlDKiCVkNaRdpFaSZdJPaQPSspKxkpOSkFKKUpCpWKlcqWdSseULis9VfpMVidbkD3JUWQeeTp5GXkruYV8kdxD/kzRoFhRvCnxlGzKPEoFpZ5yinKX8kZZWdlU2UM5RlmgPFe5QnmP8lnlB8ofVTRVbFVYKqkqEpWlKttVjqvcUnlDpVItqX7UFGoBdSm1lnqSep/6QZWm6qDKVuWpzlGtUm1Uvaz6Uo2sZqHGVJuoVqRWrrZf7aJanzpZ3VKdpc5Rn61epX5I/Yb6gAZNY4xGlEaexhKNnRrnNJ5pkjQtNQM1eZoLNLdontR8RMNoZjQWjUubT9tKO0Xr0SJqWWmxtbK1yrR2a3Vq9WtrartoJ2pP067SPqrdrYPpWOqwdXJ1luns07mu82mE4QjmCP6IxSPqR1we8V53pK6fLl+3VLdB95ruJz26XqBejt4KvSa9e/q4vq1+jP5U/Q36p/T7RmqN9BrJHVk6ct/I2waoga1BrMEMgy0GHQYDhkaGwYYiw3WGJw37jHSM/IyyjVYbHTPqNaYZ+xgLjFcbtxo/p2vTmfRcegW9nd5vYmASYiIx2WzSafLZ1Mo0wbTYtMH0nhnFjGGWYbbarM2s39zYPMJ8pnmd+W0LsgXDIstircUZi/eWVpZJlgstmyyfWelasa2KrOqs7lpTrX2tp1jXWF+1IdowbHJs1ttcskVtXW2zbKtsL9qhdm52Arv1dl2jCKM8RglH1Yy6Ya9iz7QvtK+zf+Cg4xDuUOzQ5PBytPnolNErRp8Z/c3R1THXcavjnTGaY0LHFI9pGfPaydaJ61TldNWZ6hzkPMe52fmVi50L32WDy01XmmuE60LXNtevbu5uYrd6t153c/c092r3GwwtRjRjCeOsB8HD32OOxxGPj55ungWe+zz/8rL3yvHa6fVsrNVY/titYx95m3pzvDd7d/vQfdJ8Nvl0+5r4cnxrfB/6mfnx/Lb5PWXaMLOZu5gv/R39xf4H/d+zPFmzWMcDsIDggNKAzkDNwITAysD7QaZBmUF1Qf3BrsEzgo+HEELCQlaE3GAbsrnsWnZ/qHvorND2MJWwuLDKsIfhtuHi8JYINCI0YlXE3UiLSGFkUxSIYketiroXbRU9JfpwDDEmOqYq5knsmNiZsWfiaHGT4nbGvYv3j18WfyfBOkGS0JaolpiaWJv4PikgaWVS97jR42aNu5CsnyxIbk4hpSSmbEsZGB84fs34nlTX1JLU6xOsJkybcG6i/sTciUcnqU3iTNqfRkhLStuZ9oUTxanhDKSz06vT+7ks7lruC54fbzWvl+/NX8l/muGdsTLjWaZ35qrM3izfrPKsPgFLUCl4lR2SvTH7fU5Uzvacwdyk3IY8pby0vENCTWGOsH2y0eRpk7tEdqISUfcUzylrpvSLw8Tb8pH8CfnNBVrwR75DYi35RfKg0KewqvDD1MSp+6dpTBNO65huO33x9KdFQUW/zcBncGe0zTSZOW/mg1nMWZtnI7PTZ7fNMZuzYE7P3OC5O+ZR5uXM+73YsXhl8dv5SfNbFhgumLvg0S/Bv9SVqJaIS24s9Fq4cRG+SLCoc7Hz4nWLv5XySs+XOZaVl31Zwl1y/tcxv1b8Org0Y2nnMrdlG5YTlwuXX1/hu2LHSo2VRSsfrYpY1biavrp09ds1k9acK3cp37iWslaytrsivKJ5nfm65eu+VGZVXqvyr2qoNqheXP1+PW/95Q1+
G+o3Gm4s2/hpk2DTzc3BmxtrLGvKtxC3FG55sjVx65nfGL/VbtPfVrbt63bh9u4dsTvaa91ra3ca7FxWh9ZJ6np3pe66tDtgd3O9ff3mBp2Gsj1gj2TP871pe6/vC9vXtp+xv/6AxYHqg7SDpY1I4/TG/qaspu7m5OauQ6GH2lq8Wg4edji8/YjJkaqj2keXHaMcW3BssLWodeC46HjficwTj9omtd05Oe7k1faY9s5TYafOng46ffIM80zrWe+zR855njt0nnG+6YLbhcYO146Dv7v+frDTrbPxovvF5ksel1q6xnYdu+x7+cSVgCunr7KvXrgWea3resL1mzdSb3Tf5N18div31qvbhbc/35l7l3C39J76vfL7Bvdr/rD5o6Hbrfvog4AHHQ/jHt55xH304nH+4y89C55Qn5Q/NX5a+8zp2ZHeoN5Lz8c/73khevG5r+RPjT+rX1q/PPCX318d/eP6e16JXw2+XvJG7832ty5v2waiB+6/y3v3+X3pB70POz4yPp75lPTp6eepX0hfKr7afG35Fvbt7mDe4KCII+bIfgUwWNGMDABebweAmgwADZ7PKOPl5z9ZQeRnVhkC/wnLz4iy4gZAPfx/j+mDfzc3ANizFR6/oL5aKgDRVADiPQDq7Dxch85qsnOltBDhOWBT5Nf0vHTwb4r8zPlD3D+3QKrqAn5u/wWdZ3xtG7qP3QAAADhlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAATqADAAQAAAABAAAATAAAAADhTXUdAAARnUlEQVR4Ae2c245bR3aGi4fulizFHgUzQAYIggBB5klymfeaZ8hDBYjvAiRxkMAGkowRWx7JktjcZL7vX1Uku62Burkl5YbV5q7Tqqq1/v3XqgMpL95tbvftEh6NwPLRLS4NgsAFuDOJcAHuAtyZCJzZ7MK4C3BnInBmswvjLsCdicCZzS6MOxO49Znt0uz3//CPbbv6srXFrq0W9Q6Wi0VbLPn4R8x/jSLiu3nrl8s9dcartlwtKdmTbm21XranN6v27Mm6XV8t25fP1+3Pn1+1r4if3Czbk+t9u1rR6f9jmAXc1P6sbaevQGbfdgGJeA8ke0AQsCYYgiYgPR1QyVO+3wvcMm2WO0G2PeWkX79btp839AG4//UjYC62gDsB2rI9f7pov3q2bX/9F1ftBWAufTufOcwCrnTtR90dOdHoNgCJeAbUkuM5TsWAW5W9gfkE83ZkUHg0oAyAwbm927a2ebVoP/xx2f7jD1uYuG9/89tF+/VXK1hq+88TZgG32O1g2r7tpRdBM8fUTM7pyR8SYddgxkJErUszHti7U44CpzyEo16syNtx+qgy+1og7RMetpev9+3rb3bt+c2u/ebFsv3uL1ftiqn+qcMs4HY7jNQpEfadNU5VqeHUTJkgUbaPDxRADdZ8jU9LHoJYnwLUtgWN4ObDC7Kdr8Hp7d9qMTW8gt23V1zyvPrD1H56e9t+99vr9uJLprBDfaIw69U4dQRCIw2JdVIjbUzecj+7qYyPpZHiAbDaJwsXyMhQEQ0pq6sAp7hMS2XGqykdA2iy4EUtF6v206ur9k/fbNo//+frtt2OaW/rjxtmAaeNGqihBY5xfVQzQEZfoSH0KHgkrbD/CX6vPIqlSTU61vVCovRSbEwbIS851vj23Q+tff3vu/bzu5I7tvs4qVnADTa5FCbNC86qCLN2E1MxKKroYB2pgSz2RLbbVcVkSJhOKxIDjGxn+nSuqes2JlKuG8fA/IzPXazbj68X7et/27UfX7GifORwOuSju47h/c3beKfRFO74CNA04YP0ZT2/YzERFGojc9pmDG47/wyDZwJjiX4wwJNer1dZPJbs5/xzK5Ppzp7SQZBszNy22U7tX7/dtFdvJrv8aGE2cDJLoPycBgHSgICJUQLo8nmUo6y7oH0S5Lu/FGhDQULCfIooATw3yyOQQ46eYVpYiaBMTFtAFPR307r9y3fbdvsRfd5Rg6HJI2Lt1qaAF6TEqoxWdVdYSHawezCvAHLjW7Jh2QGcUkDDT4Og2OfSFRVkxipcAJUZARC5FVRbeRpB1hVY6r25XQHexIZ96Hfa++PTs4Dbi8rQg7imWQG27/uEgCTCssk/WWg7GwJWwDQ36PceGzQ+x7jOtgNogkIIpsZiFMdXoEfOPUlh3l5ulu2/X6bJ7Mc84Bw+xgOKzJqM0VKm8WYlVMqt61gFKNtQKeZ6o7Ls/aqEeYooJXDIZ9uiT0uZ5UxPUJNlYdoAK62qHfM7unz3/bb9/Ha+v3u/tn3AD0XOrnxAZdpNYZILgoxyGk4BqMCbssq66dXv6RdFkiB6Rj2u3N1npiMw1dQjF4oJW/kzy6VdMRFA9Xd8VvhCLxCyYUYkvhHZb7+fotvdUR6XmwXcYI1DangAA6yspgBj/dRjp6L+RbmSPaaxuuMnGEeVAhBF4pSapAFG5gUo60rAHmpVtcz0sR2aBZW8NAB9+W7dXr9N0dmPmUcu10pWrq7kQQvBQXn1dUsgoM4ej12TtyBknG51PEMGOV2TLLVZ/GLvLMBYHsYJhg7fuMBx6tq3LFu7aBxxD9jKFiO7Thbwcv7n5dS+/ML0eWEWcBqoptk+mEQp2aTG+rbmBYA+D6MyMwMAdepKsX5QpnglFZyZ5k4tDYsI/Y1pF7CRq22HoHXgGEOwgodvgH79INnW3tlFIVVQvkBXg1dvF3z27fkTGzw+zALOPZluVoVkV4yLHoBB3VBJUNyo6uEWXAyIkruC2OQjbVeppxkm8+iti2mySsM1EPYGKBcEyul3LKTW1+pr+wLRstwP0J8a2K95Txf/+6q1ZzeUDEXt/oFhHnA4fJYCBtawYlWmlsrJBEHhP43bi9Rq1Z0ymlK3Z/QCRqA5YfaNLZJWEACn929eluXlUGO8CgMrHWYi441S2tsFebLRL5RWL0e0nL64SEEf2sjMR4ZZwA0Ddfziclz1eN8yDn1qAaHSq3G0FEQXjABDo51sJVNyGnA0QlAPL4LOApzMo0mY1sUFbQBj8xTzYhKrROYF5VGIftR1uW3+3uiWU8XnBw7l3HIYVG/P/djYgMZoyrTJrci0n2qPZVnNFV913viW6btGzsXBT6aW3VKmsauVTFOc2DxpP5YJYLBBeCUixE71IlGBR2EF+6OugHbP12Ddoj29HgIPj+cxDiPDFGINzB8sKhLh0Ui4gOgDI8deb8FiwYxlteWhLHWTlmOzhkxLAObPIkFqS8+bbG5BdgWiAmJTwXdqZ7oysktzdKC/BWMWiAJNpyP0ZPTMItRy7fTi2RB4eDwLuIkpCma1gob/Dsw7zcKAMf3txiCot8c42ZCDPu3WAqRMJAGEk4cACaLzSZsFRhAE9QoAtXcwTX92XDT0sxTQXJYHdDJin0KfVN8PmzNvnOYBx5XNlik4giumihb7tJ60ezgNhgXuXgRNttxunZYAj7uzbL3nUA67rm5KJWrJCyTfIVwBMh3bTkD8TqFYp6uv8RwrgJpAZmHHScqv0qWeKT48NujhAuE
LekyYBdz9gXJQ53DvDh3tU62xTtN8bQhzzE9OccAK8wA2ez2k3cNtN7wM/RZs9M5NkNZoee0H2rmhLr8miPV9roAZtN1RHV/gDb7EoUtXKeXjYXUBN0oeFs8CbrtlhZRGPZSSZNyI9gA+TBFkelFNWxgEgCtG3wDiFqEr5Jz6y/U1DAM4QLxi2l7DNhl3w/epNTUFWGbXC7HrMQMz7WUbf8AaDQ46DYXuxLoJX6CFRzvuiPyJzCzgZIoKyqgKAx1yAGPQUWfa+GoDsqwDJNnHLF9juSz0i5VrpvqSwmsQul5dtyfrfX1zL3i0WdHHSjaKVjf0T5k7ABtxlEHbwxusgjydAY8N84BjvAx5GLfMqBW0VJEZ+pwKskQnbpnFHPzpwWo/bzkGvX51296+bu1v/+qL9usXT9rTJ07Bzh9k9HEPsxNhwhh6xLXKo3fXWf3iMkrBBz9nAbflbHm6ONxhXp8/NW26lkSleIEV9FBVI+o6ihjmffPDt+3v/+5Z+82vnsZw/fyercweB2d7wzA8mfuPEknpXTnHvQsoPd1v/aD8LODw+AxbAw/QjnEfv69u5kz6dtOiW2R6YmW7vd0C3qK94wcjf/zxZ1bRXfvqGT6U3f2G/Z6AesqotgJX477PNVmTmxfiwTSS5irqz2ybEHD6PzbMAk7lS/0BxgkTqPAUYBiAkQpTLLdKxe1D4Lbsp968uW1vXk+ZrnpsN7yL1TbmbvCl4GcPPPStZWyNcM9s++9y92ruZu2CT21q7lZ9KDcLuC3WbmGG42uA30EISOVkFynt1BBialOliF/wZHqGTa1tOfq8fbMHPL6N2iBPW2d7HfxZdWnreiN49UL0dfhLR6tBSVVwNo+TQ1U5IsHvQU4Dcry7bGNOix+SngVcwAhYpZjTQxaNMABLLLtUFEAMEwi4kk63fGDbLTcVm82ubd7hNylzEXCa6SPdz2Vf5iUobe0jAFIq8+JHT8CjGeUjHFOj5E7MIO4THxvOaHIcwu2IOKiznyg89BTEXi6WssO8B36vkLa33Pv7/QRbEtm21c/BtIm9Yb4ho19PDg4g09aeucySdpzq3BfVx6WQqh7MkLOSkHLf2olEKni4n7xznh0VH4jnAYdy6hfVSZTvUmF54f2cU9d9XmlhvUyTlbkxIT0BWtgH4wRRgPMy7EFbAwi8ojzbNyqtH/7coWxnUHyE+rmYjbs3NCnqdwIbbM/GZ4RZwDleVskO3viSBhWjSu2Pxj7JU4bsqrzTU5YZQ7xKu73Bb8bAbo+s28NStxEyb8e+K1UAKXhOVivK7x0RUANf3zEw/smJpsr37cad9RlhFnCbzQYwfN36I+5qwxgVwRA/vOHxlneeMiaux9lymN5tTTttkZN5mbZwCYsLM550taA+zJM5gsdHsGSdQTbngN7ZlC/JrRhXIcorRJvVcp2pnjzdy+0nnErOCbOAE5x8d4oVCy4xMSFGetjfgWJ3MQFHdomxZbUwwC4B84YlzBNojUEmxmqO1tVC4VcVopUzKuXK+XArUeDVTyq85wv7xKqHsel1dfIUkl8zUXcFm8eUH7IPjWcBp8J5mYxWcWmbclhlyEIAMJm2HbSwDCHZGD9IuR1UH4MhaZ4HOAIQIJOrIxfjxOFRUMNQq8wI9EH5WNVJdcEje22ofxs3K6PlQ+OZwA2ghrFSKhiEVSqh/5JJcfodKBnntLac7wb5CKLpAs+0RguYuAhoNh2CRV1dTVFhqWhRn/u+tOsMtTph6JhOkAWsQDz1K3NHeHyYBZyK70BG5oy3SyqGumoaAhr1Aiggnm8FzXr3cQWSq++p8seM10v6LW9Elgh5kyGINXMdi1xspw2LRHwqMjJTV2KdU9c2eQ1SkXDDHL2aYf2MprVp1dFrtcBlAWB/sNuxMoJIzEfRqhMk04qXfM0n8yVDaa/DRLp1GuGSKhNz65ZEOQUSdyD0Y/adRSojsxjoz2jnNFdN3l/S+sUvnqbDsx+zgCvQMJzhPaCrlouCLBvbA43x68DhsAc7DxpTr0y39VAMBCfpSlpSUMggzRe8X4bIAWRYJqVJj6t7feMV/9Bkfeb+bYw2Czg78S3GwWtEQEPRWFMMEDAZhVTiMaWLnZZRxSexfaStPR9DAXbMj5Qs479Dm8PqqYCNEpUTVAe/GpLC3vH16hI64zkLuB1XQVsdFkED8ps40oLjj2sMAdbFwGlKRjbW6UHAFZaRJVegIpeWVafZhQ4yHahUm+5VyfOwXYFHTX8DKUNSn+fCcsN3qOd8AT3GGPEs4EYnxho9YlOnU1WTUj98GbLKWCawI5wk71DiBMoh+qjYfgXUc+nNlW+rXuqjOrknPAs4sRoHcvvNguDZNEChYOoBUUZ175z9nMBZnQ6cnncgS7uDnt3BJ49Y8axqPYLZ0gVEb2DaICyHtOUM5t2eP7AJexWaGWYBVzcdsqneoAAViyzzo3ZsC1Jeq2qBKVhlkIxDsuSRrSY6/6S6eaaFjD+B4BGmMo9X9M06kcAdMq0qU5eT+lBBc8+GqaVmCc989iHP6yVvOcr4qE8ZLijVZ8VleC/5xWDWFmN6ow6aIKX75EfdL5rfKxBJgAcwwV/zeXrFjyqqo3uy52dnMa5oU4O7svo7YMNgWrFKdsk6WBXmmS82HuKsuADjHZFGi5iBIv+9qnn/qt+qSh3JTFNjPvWDiqpnA0SexYB/ijm6q5qP85wFnIZrXQHgillpVesHh9QVaAWWAJccfo/VNrOcbmrbYn/vCR9gy2m1aUH2WOa/rv4UoKnhPODowC2Gx6jQo4Nox4ZinDL392ssIHFSZWa1rTZJD/wSy0Kn34eDpwZvP1w96+dmH25zrsQs4KSLP4GAawWSjhnFZZQFmUZxOZSTj/ne2yUhIHCjRIlFKcIU0x852RjZTGGlDdaQrkxk7MPrJr/gzg17r4vgJ3rMAk4/wmQDE7wJhg+fFV1xaMGiMqnXaFc5jd4FjCCIRAEmAO5aPE7lzsw0ZelHYJB0PCWscErqOJcsrbllGmhmzE/7mAXcPof544Wlqg6wTuORtvKQzjV2gVC+shaNMhc24v8iIloGmS3ogc7bD9sS884Oi0kEP89jFnDX++/hCtPVtT7kwaxOkZpmxQ/L9vgdj1r+NCtAwQ6/A9DXMXnBqZgoHDdXP7Wna/Id6PRCum7DiREqcg1UPw9Yp6MsLv/HwlM4Hp7WQ1/CGQhcgDsDNJtcgLsAdyYCZza7MO4C3JkInNnswrgLcGcicGazC+POBO7/AH5zPa/ivytzAAAAAElFTkSuQmCC'
                    )
                ]
            )
        ],
        model_parameters={
            'temperature': 0.1,
            'num_predict': 100
        },
        stream=False,
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


def test_get_num_tokens():
    model = OllamaLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 2048,
            'max_tokens': 2048,
        },
        prompt_messages=[
            UserPromptMessage(content='Hello World!')
        ]
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 6
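For orientation, the chat-mode tests above ultimately exercise an Ollama server's /api/chat endpoint. A hedged sketch of the equivalent raw request (field names follow Ollama's public REST API; Dify's OllamaLargeLanguageModel wraps this, so its internals may differ):

import os

import requests

response = requests.post(
    f"{os.environ['OLLAMA_BASE_URL']}/api/chat",
    json={
        "model": "mistral:text",
        "messages": [{"role": "user", "content": "Who are you?"}],
        "stream": False,
        # sampling options mirror the model_parameters used in the tests
        "options": {
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
            "num_predict": 10,
            "stop": ["How"],
        },
    },
    timeout=120,
)
print(response.json()["message"]["content"])  # the assistant's reply text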
api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py
0 → 100644
View file @
cca9edc9
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel


def test_validate_credentials():
    model = OllamaEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistral:text',
            credentials={
                'base_url': 'http://localhost:21434',
                'mode': 'chat',
                'context_size': 4096,
            }
        )

    model.validate_credentials(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        }
    )


def test_invoke_model():
    model = OllamaEmbeddingModel()

    result = model.invoke(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2


def test_get_num_tokens():
    model = OllamaEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='mistral:text',
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        },
        texts=[
            "hello",
            "world"
        ]
    )

    assert num_tokens == 2
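Likewise, these embedding tests boil down to Ollama's /api/embeddings endpoint, which embeds a single prompt per request, so a batch of texts means one call per text. A hedged sketch of the raw shape (per Ollama's public API; Dify's OllamaEmbeddingModel wraps this and may differ):

import os

import requests

def embed_texts(texts: list[str]) -> list[list[float]]:
    vectors = []
    for text in texts:
        # one request per text: the endpoint accepts a single "prompt"
        response = requests.post(
            f"{os.environ['OLLAMA_BASE_URL']}/api/embeddings",
            json={"model": "mistral:text", "prompt": text},
            timeout=120,
        )
        vectors.append(response.json()["embedding"])
    return vectors

print(len(embed_texts(["hello", "world"])))  # 2, matching the assertion above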