Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
3ea8d7a0
Unverified
Commit
3ea8d7a0
authored
Aug 20, 2023
by
takatost
Committed by
GitHub
Aug 20, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add openllm support (#928)
parent
da3f10a5
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
412 additions
and
3 deletions
+412
-3
model_provider_factory.py
api/core/model_providers/model_provider_factory.py
+3
-0
openllm_model.py
api/core/model_providers/models/llm/openllm_model.py
+60
-0
openllm_provider.py
api/core/model_providers/providers/openllm_provider.py
+137
-0
_providers.json
api/core/model_providers/rules/_providers.json
+2
-1
openllm.json
api/core/model_providers/rules/openllm.json
+7
-0
requirements.txt
api/requirements.txt
+2
-1
.env.example
api/tests/integration_tests/.env.example
+4
-1
test_openllm_model.py
api/tests/integration_tests/models/llm/test_openllm_model.py
+72
-0
test_openllm_provider.py
...tests/unit_tests/model_providers/test_openllm_provider.py
+125
-0
No files found.
api/core/model_providers/model_provider_factory.py
View file @
3ea8d7a0
...
...
@@ -60,6 +60,9 @@ class ModelProviderFactory:
elif
provider_name
==
'xinference'
:
from
core.model_providers.providers.xinference_provider
import
XinferenceProvider
return
XinferenceProvider
elif
provider_name
==
'openllm'
:
from
core.model_providers.providers.openllm_provider
import
OpenLLMProvider
return
OpenLLMProvider
else
:
raise
NotImplementedError
...
...
api/core/model_providers/models/llm/openllm_model.py
0 → 100644
View file @
3ea8d7a0
from
typing
import
List
,
Optional
,
Any
from
langchain.callbacks.manager
import
Callbacks
from
langchain.llms
import
OpenLLM
from
langchain.schema
import
LLMResult
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
class
OpenLLMModel
(
BaseLLM
):
model_mode
:
ModelMode
=
ModelMode
.
COMPLETION
def
_init_client
(
self
)
->
Any
:
self
.
provider_model_kwargs
=
self
.
_to_model_kwargs_input
(
self
.
model_rules
,
self
.
model_kwargs
)
client
=
OpenLLM
(
server_url
=
self
.
credentials
.
get
(
'server_url'
),
callbacks
=
self
.
callbacks
,
**
self
.
provider_model_kwargs
)
return
client
def
_run
(
self
,
messages
:
List
[
PromptMessage
],
stop
:
Optional
[
List
[
str
]]
=
None
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
->
LLMResult
:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts
=
self
.
_get_prompt_from_messages
(
messages
)
return
self
.
_client
.
generate
([
prompts
],
stop
,
callbacks
)
def
get_num_tokens
(
self
,
messages
:
List
[
PromptMessage
])
->
int
:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts
=
self
.
_get_prompt_from_messages
(
messages
)
return
max
(
self
.
_client
.
get_num_tokens
(
prompts
),
0
)
def
_set_model_kwargs
(
self
,
model_kwargs
:
ModelKwargs
):
pass
def
handle_exceptions
(
self
,
ex
:
Exception
)
->
Exception
:
return
LLMBadRequestError
(
f
"OpenLLM: {str(ex)}"
)
@
classmethod
def
support_streaming
(
cls
):
return
False
api/core/model_providers/providers/openllm_provider.py
0 → 100644
View file @
3ea8d7a0
import
json
from
typing
import
Type
from
langchain.llms
import
OpenLLM
from
core.helper
import
encrypter
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
from
core.model_providers.models.llm.openllm_model
import
OpenLLMModel
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.models.base
import
BaseProviderModel
from
models.provider
import
ProviderType
class
OpenLLMProvider
(
BaseModelProvider
):
@
property
def
provider_name
(
self
):
"""
Returns the name of a provider.
"""
return
'openllm'
def
_get_fixed_model_list
(
self
,
model_type
:
ModelType
)
->
list
[
dict
]:
return
[]
def
get_model_class
(
self
,
model_type
:
ModelType
)
->
Type
[
BaseProviderModel
]:
"""
Returns the model class.
:param model_type:
:return:
"""
if
model_type
==
ModelType
.
TEXT_GENERATION
:
model_class
=
OpenLLMModel
else
:
raise
NotImplementedError
return
model_class
def
get_model_parameter_rules
(
self
,
model_name
:
str
,
model_type
:
ModelType
)
->
ModelKwargsRules
:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return
ModelKwargsRules
(
temperature
=
KwargRule
[
float
](
min
=
0
,
max
=
2
,
default
=
1
),
top_p
=
KwargRule
[
float
](
min
=
0
,
max
=
1
,
default
=
0.7
),
presence_penalty
=
KwargRule
[
float
](
min
=-
2
,
max
=
2
,
default
=
0
),
frequency_penalty
=
KwargRule
[
float
](
min
=-
2
,
max
=
2
,
default
=
0
),
max_tokens
=
KwargRule
[
int
](
min
=
10
,
max
=
4000
,
default
=
128
),
)
@
classmethod
def
is_model_credentials_valid_or_raise
(
cls
,
model_name
:
str
,
model_type
:
ModelType
,
credentials
:
dict
):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if
'server_url'
not
in
credentials
:
raise
CredentialsValidateFailedError
(
'OpenLLM Server URL must be provided.'
)
try
:
credential_kwargs
=
{
'server_url'
:
credentials
[
'server_url'
]
}
llm
=
OpenLLM
(
max_tokens
=
10
,
**
credential_kwargs
)
llm
(
"ping"
)
except
Exception
as
ex
:
raise
CredentialsValidateFailedError
(
str
(
ex
))
@
classmethod
def
encrypt_model_credentials
(
cls
,
tenant_id
:
str
,
model_name
:
str
,
model_type
:
ModelType
,
credentials
:
dict
)
->
dict
:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials
[
'server_url'
]
=
encrypter
.
encrypt_token
(
tenant_id
,
credentials
[
'server_url'
])
return
credentials
def
get_model_credentials
(
self
,
model_name
:
str
,
model_type
:
ModelType
,
obfuscated
:
bool
=
False
)
->
dict
:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if
self
.
provider
.
provider_type
!=
ProviderType
.
CUSTOM
.
value
:
raise
NotImplementedError
provider_model
=
self
.
_get_provider_model
(
model_name
,
model_type
)
if
not
provider_model
.
encrypted_config
:
return
{
'server_url'
:
None
}
credentials
=
json
.
loads
(
provider_model
.
encrypted_config
)
if
credentials
[
'server_url'
]:
credentials
[
'server_url'
]
=
encrypter
.
decrypt_token
(
self
.
provider
.
tenant_id
,
credentials
[
'server_url'
]
)
if
obfuscated
:
credentials
[
'server_url'
]
=
encrypter
.
obfuscated_token
(
credentials
[
'server_url'
])
return
credentials
@
classmethod
def
is_provider_credentials_valid_or_raise
(
cls
,
credentials
:
dict
):
return
@
classmethod
def
encrypt_provider_credentials
(
cls
,
tenant_id
:
str
,
credentials
:
dict
)
->
dict
:
return
{}
def
get_provider_credentials
(
self
,
obfuscated
:
bool
=
False
)
->
dict
:
return
{}
api/core/model_providers/rules/_providers.json
View file @
3ea8d7a0
...
...
@@ -9,5 +9,6 @@
"chatglm"
,
"replicate"
,
"huggingface_hub"
,
"xinference"
"xinference"
,
"openllm"
]
\ No newline at end of file
api/core/model_providers/rules/openllm.json
0 → 100644
View file @
3ea8d7a0
{
"support_provider_types"
:
[
"custom"
],
"system_config"
:
null
,
"model_flexibility"
:
"configurable"
}
\ No newline at end of file
api/requirements.txt
View file @
3ea8d7a0
...
...
@@ -49,4 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.0
\ No newline at end of file
xinference==0.2.0
openllm~=0.2.26
\ No newline at end of file
api/tests/integration_tests/.env.example
View file @
3ea8d7a0
...
...
@@ -36,4 +36,7 @@ CHATGLM_API_BASE=
# Xinference Credentials
XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
\ No newline at end of file
XINFERENCE_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
\ No newline at end of file
api/tests/integration_tests/models/llm/test_openllm_model.py
0 → 100644
View file @
3ea8d7a0
import
json
import
os
from
unittest.mock
import
patch
,
MagicMock
from
core.model_providers.models.entity.message
import
PromptMessage
,
MessageType
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelType
from
core.model_providers.models.llm.openllm_model
import
OpenLLMModel
from
core.model_providers.providers.openllm_provider
import
OpenLLMProvider
from
models.provider
import
Provider
,
ProviderType
,
ProviderModel
def
get_mock_provider
():
return
Provider
(
id
=
'provider_id'
,
tenant_id
=
'tenant_id'
,
provider_name
=
'openllm'
,
provider_type
=
ProviderType
.
CUSTOM
.
value
,
encrypted_config
=
''
,
is_valid
=
True
,
)
def
get_mock_model
(
model_name
,
mocker
):
model_kwargs
=
ModelKwargs
(
max_tokens
=
10
,
temperature
=
0.01
)
server_url
=
os
.
environ
[
'OPENLLM_SERVER_URL'
]
model_provider
=
OpenLLMProvider
(
provider
=
get_mock_provider
())
mock_query
=
MagicMock
()
mock_query
.
filter
.
return_value
.
first
.
return_value
=
ProviderModel
(
provider_name
=
'openllm'
,
model_name
=
model_name
,
model_type
=
ModelType
.
TEXT_GENERATION
.
value
,
encrypted_config
=
json
.
dumps
({
'server_url'
:
server_url
}),
is_valid
=
True
,
)
mocker
.
patch
(
'extensions.ext_database.db.session.query'
,
return_value
=
mock_query
)
return
OpenLLMModel
(
model_provider
=
model_provider
,
name
=
model_name
,
model_kwargs
=
model_kwargs
)
def
decrypt_side_effect
(
tenant_id
,
encrypted_api_key
):
return
encrypted_api_key
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_get_num_tokens
(
mock_decrypt
,
mocker
):
model
=
get_mock_model
(
'facebook/opt-125m'
,
mocker
)
rst
=
model
.
get_num_tokens
([
PromptMessage
(
type
=
MessageType
.
HUMAN
,
content
=
'Who is your manufacturer?'
)
])
assert
rst
==
5
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_run
(
mock_decrypt
,
mocker
):
mocker
.
patch
(
'core.model_providers.providers.base.BaseModelProvider.update_last_used'
,
return_value
=
None
)
model
=
get_mock_model
(
'facebook/opt-125m'
,
mocker
)
messages
=
[
PromptMessage
(
content
=
'Human: who are you?
\n
Answer: '
)]
rst
=
model
.
run
(
messages
)
assert
len
(
rst
.
content
)
>
0
api/tests/unit_tests/model_providers/test_openllm_provider.py
0 → 100644
View file @
3ea8d7a0
import
pytest
from
unittest.mock
import
patch
,
MagicMock
import
json
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.providers.base
import
CredentialsValidateFailedError
from
core.model_providers.providers.openllm_provider
import
OpenLLMProvider
from
models.provider
import
ProviderType
,
Provider
,
ProviderModel
PROVIDER_NAME
=
'openllm'
MODEL_PROVIDER_CLASS
=
OpenLLMProvider
VALIDATE_CREDENTIAL
=
{
'server_url'
:
'http://127.0.0.1:3333/'
}
def
encrypt_side_effect
(
tenant_id
,
encrypt_key
):
return
f
'encrypted_{encrypt_key}'
def
decrypt_side_effect
(
tenant_id
,
encrypted_key
):
return
encrypted_key
.
replace
(
'encrypted_'
,
''
)
def
test_is_credentials_valid_or_raise_valid
(
mocker
):
mocker
.
patch
(
'langchain.llms.openllm.OpenLLM._identifying_params'
,
return_value
=
None
)
mocker
.
patch
(
'langchain.llms.openllm.OpenLLM._call'
,
return_value
=
"abc"
)
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
model_name
=
'username/test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
,
credentials
=
VALIDATE_CREDENTIAL
.
copy
()
)
def
test_is_credentials_valid_or_raise_invalid
(
mocker
):
mocker
.
patch
(
'langchain.llms.openllm.OpenLLM._identifying_params'
,
return_value
=
None
)
# raise CredentialsValidateFailedError if credential is not in credentials
with
pytest
.
raises
(
CredentialsValidateFailedError
):
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
model_name
=
'test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
,
credentials
=
{}
)
# raise CredentialsValidateFailedError if credential is invalid
with
pytest
.
raises
(
CredentialsValidateFailedError
):
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
model_name
=
'test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
,
credentials
=
{
'server_url'
:
'invalid'
})
@
patch
(
'core.helper.encrypter.encrypt_token'
,
side_effect
=
encrypt_side_effect
)
def
test_encrypt_model_credentials
(
mock_encrypt
):
api_key
=
'http://127.0.0.1:3333/'
result
=
MODEL_PROVIDER_CLASS
.
encrypt_model_credentials
(
tenant_id
=
'tenant_id'
,
model_name
=
'test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
,
credentials
=
VALIDATE_CREDENTIAL
.
copy
()
)
mock_encrypt
.
assert_called_with
(
'tenant_id'
,
api_key
)
assert
result
[
'server_url'
]
==
f
'encrypted_{api_key}'
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_get_model_credentials_custom
(
mock_decrypt
,
mocker
):
provider
=
Provider
(
id
=
'provider_id'
,
tenant_id
=
'tenant_id'
,
provider_name
=
PROVIDER_NAME
,
provider_type
=
ProviderType
.
CUSTOM
.
value
,
encrypted_config
=
None
,
is_valid
=
True
,
)
encrypted_credential
=
VALIDATE_CREDENTIAL
.
copy
()
encrypted_credential
[
'server_url'
]
=
'encrypted_'
+
encrypted_credential
[
'server_url'
]
mock_query
=
MagicMock
()
mock_query
.
filter
.
return_value
.
first
.
return_value
=
ProviderModel
(
encrypted_config
=
json
.
dumps
(
encrypted_credential
)
)
mocker
.
patch
(
'extensions.ext_database.db.session.query'
,
return_value
=
mock_query
)
model_provider
=
MODEL_PROVIDER_CLASS
(
provider
=
provider
)
result
=
model_provider
.
get_model_credentials
(
model_name
=
'test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
)
assert
result
[
'server_url'
]
==
'http://127.0.0.1:3333/'
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_get_model_credentials_obfuscated
(
mock_decrypt
,
mocker
):
provider
=
Provider
(
id
=
'provider_id'
,
tenant_id
=
'tenant_id'
,
provider_name
=
PROVIDER_NAME
,
provider_type
=
ProviderType
.
CUSTOM
.
value
,
encrypted_config
=
None
,
is_valid
=
True
,
)
encrypted_credential
=
VALIDATE_CREDENTIAL
.
copy
()
encrypted_credential
[
'server_url'
]
=
'encrypted_'
+
encrypted_credential
[
'server_url'
]
mock_query
=
MagicMock
()
mock_query
.
filter
.
return_value
.
first
.
return_value
=
ProviderModel
(
encrypted_config
=
json
.
dumps
(
encrypted_credential
)
)
mocker
.
patch
(
'extensions.ext_database.db.session.query'
,
return_value
=
mock_query
)
model_provider
=
MODEL_PROVIDER_CLASS
(
provider
=
provider
)
result
=
model_provider
.
get_model_credentials
(
model_name
=
'test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
,
obfuscated
=
True
)
middle_token
=
result
[
'server_url'
][
6
:
-
2
]
assert
len
(
middle_token
)
==
max
(
len
(
VALIDATE_CREDENTIAL
[
'server_url'
])
-
8
,
0
)
assert
all
(
char
==
'*'
for
char
in
middle_token
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment