Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
b7c29ea1
Unverified
Commit
b7c29ea1
authored
Aug 16, 2023
by
takatost
Committed by
GitHub
Aug 16, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: optimize model when app create (#875)
parent
cc2d71c2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
104 additions
and
23 deletions
+104
-23
api-unit-tests.yml
.github/workflows/api-unit-tests.yml
+38
-0
app.py
api/controllers/console/app/app.py
+43
-18
test_huggingface_hub_provider.py
...it_tests/model_providers/test_huggingface_hub_provider.py
+3
-2
test_replicate_provider.py
...sts/unit_tests/model_providers/test_replicate_provider.py
+19
-2
test_tongyi_provider.py
api/tests/unit_tests/model_providers/test_tongyi_provider.py
+1
-1
No files found.
.github/workflows/api-unit-tests.yml
0 → 100644
View file @
b7c29ea1
name
:
Run Pytest
on
:
pull_request
:
branches
:
-
main
push
:
branches
:
-
deploy/dev
jobs
:
test
:
runs-on
:
ubuntu-latest
steps
:
-
name
:
Checkout code
uses
:
actions/checkout@v2
-
name
:
Set up Python
uses
:
actions/setup-python@v2
with
:
python-version
:
'
3.10'
-
name
:
Cache pip dependencies
uses
:
actions/cache@v2
with
:
path
:
~/.cache/pip
key
:
${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys
:
${{ runner.os }}-pip-
-
name
:
Install dependencies
run
:
|
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
-
name
:
Run pytest
run
:
pytest api/tests/unit_tests
api/controllers/console/app/app.py
View file @
b7c29ea1
# -*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
import
json
import
json
import
logging
from
datetime
import
datetime
from
datetime
import
datetime
from
flask_login
import
login_required
,
current_user
from
flask_login
import
login_required
,
current_user
...
@@ -11,7 +12,9 @@ from controllers.console import api
...
@@ -11,7 +12,9 @@ from controllers.console import api
from
controllers.console.app.error
import
AppNotFoundError
,
ProviderNotInitializeError
from
controllers.console.app.error
import
AppNotFoundError
,
ProviderNotInitializeError
from
controllers.console.setup
import
setup_required
from
controllers.console.setup
import
setup_required
from
controllers.console.wraps
import
account_initialization_required
from
controllers.console.wraps
import
account_initialization_required
from
core.model_providers.error
import
ProviderTokenNotInitError
,
LLMBadRequestError
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_provider_factory
import
ModelProviderFactory
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.models.entity.model_params
import
ModelType
from
events.app_event
import
app_was_created
,
app_was_deleted
from
events.app_event
import
app_was_created
,
app_was_deleted
from
libs.helper
import
TimestampField
from
libs.helper
import
TimestampField
...
@@ -124,24 +127,34 @@ class AppListApi(Resource):
...
@@ -124,24 +127,34 @@ class AppListApi(Resource):
if
current_user
.
current_tenant
.
current_role
not
in
[
'admin'
,
'owner'
]:
if
current_user
.
current_tenant
.
current_role
not
in
[
'admin'
,
'owner'
]:
raise
Forbidden
()
raise
Forbidden
()
default_model
=
ModelFactory
.
get_default_model
(
try
:
tenant_id
=
current_user
.
current_tenant_id
,
default_model
=
ModelFactory
.
get_text_generation_model
(
model_type
=
ModelType
.
TEXT_GENERATION
tenant_id
=
current_user
.
current_tenant_id
)
)
except
(
ProviderTokenNotInitError
,
LLMBadRequestError
):
if
default_model
:
default_model
=
None
default_model_provider
=
default_model
.
provider_name
except
Exception
as
e
:
default_model_name
=
default_model
.
model_name
logging
.
exception
(
e
)
else
:
default_model
=
None
raise
ProviderNotInitializeError
(
f
"No Text Generation Model available. Please configure a valid provider "
f
"in the Settings -> Model Provider."
)
if
args
[
'model_config'
]
is
not
None
:
if
args
[
'model_config'
]
is
not
None
:
# validate config
# validate config
model_config_dict
=
args
[
'model_config'
]
model_config_dict
=
args
[
'model_config'
]
model_config_dict
[
"model"
][
"provider"
]
=
default_model_provider
model_config_dict
[
"model"
][
"name"
]
=
default_model_name
# get model provider
model_provider
=
ModelProviderFactory
.
get_preferred_model_provider
(
current_user
.
current_tenant_id
,
model_config_dict
[
"model"
][
"provider"
]
)
if
not
model_provider
:
if
not
default_model
:
raise
ProviderNotInitializeError
(
f
"No Default System Reasoning Model available. Please configure "
f
"in the Settings -> Model Provider."
)
else
:
model_config_dict
[
"model"
][
"provider"
]
=
default_model
.
model_provider
.
provider_name
model_config_dict
[
"model"
][
"name"
]
=
default_model
.
name
model_configuration
=
AppModelConfigService
.
validate_configuration
(
model_configuration
=
AppModelConfigService
.
validate_configuration
(
tenant_id
=
current_user
.
current_tenant_id
,
tenant_id
=
current_user
.
current_tenant_id
,
...
@@ -169,10 +182,22 @@ class AppListApi(Resource):
...
@@ -169,10 +182,22 @@ class AppListApi(Resource):
app
=
App
(
**
model_config_template
[
'app'
])
app
=
App
(
**
model_config_template
[
'app'
])
app_model_config
=
AppModelConfig
(
**
model_config_template
[
'model_config'
])
app_model_config
=
AppModelConfig
(
**
model_config_template
[
'model_config'
])
model_dict
=
app_model_config
.
model_dict
# get model provider
model_dict
[
'provider'
]
=
default_model_provider
model_provider
=
ModelProviderFactory
.
get_preferred_model_provider
(
model_dict
[
'name'
]
=
default_model_name
current_user
.
current_tenant_id
,
app_model_config
.
model
=
json
.
dumps
(
model_dict
)
app_model_config
.
model_dict
[
"provider"
]
)
if
not
model_provider
:
if
not
default_model
:
raise
ProviderNotInitializeError
(
f
"No Default System Reasoning Model available. Please configure "
f
"in the Settings -> Model Provider."
)
else
:
model_dict
=
app_model_config
.
model_dict
model_dict
[
'provider'
]
=
default_model
.
model_provider
.
provider_name
model_dict
[
'name'
]
=
default_model
.
name
app_model_config
.
model
=
json
.
dumps
(
model_dict
)
app
.
name
=
args
[
'name'
]
app
.
name
=
args
[
'name'
]
app
.
mode
=
args
[
'mode'
]
app
.
mode
=
args
[
'mode'
]
...
...
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py
View file @
b7c29ea1
...
@@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
...
@@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
@
patch
(
'huggingface_hub.hf_api.ModelInfo'
)
@
patch
(
'huggingface_hub.hf_api.ModelInfo'
)
def
test_hosted_inference_api_is_credentials_valid_or_raise_valid
(
mock_model_info
,
mocker
):
def
test_hosted_inference_api_is_credentials_valid_or_raise_valid
(
mock_model_info
,
mocker
):
mock_model_info
.
return_value
=
MagicMock
(
pipeline_tag
=
'text2text-generation'
)
mock_model_info
.
return_value
=
MagicMock
(
pipeline_tag
=
'text2text-generation'
,
cardData
=
{
'inference'
:
True
})
mocker
.
patch
(
'langchain.llms.huggingface_hub.HuggingFaceHub._call'
,
return_value
=
"abc"
)
mocker
.
patch
(
'huggingface_hub.hf_api.HfApi.whoami'
,
return_value
=
"abc"
)
mocker
.
patch
(
'huggingface_hub.hf_api.HfApi.model_info'
,
return_value
=
mock_model_info
.
return_value
)
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
model_name
=
'test_model_name'
,
model_name
=
'test_model_name'
,
...
...
api/tests/unit_tests/model_providers/test_replicate_provider.py
View file @
b7c29ea1
...
@@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
...
@@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
return
encrypted_key
.
replace
(
'encrypted_'
,
''
)
return
encrypted_key
.
replace
(
'encrypted_'
,
''
)
def
version_effect
(
id
:
str
):
mock_version
=
MagicMock
()
mock_version
.
openapi_schema
=
{
'components'
:
{
'schemas'
:
{
'Output'
:
{
'items'
:
{
'type'
:
'string'
}
}
}
}
}
return
mock_version
@
patch
(
'replicate.version.VersionCollection.get'
,
side_effect
=
version_effect
)
def
test_is_credentials_valid_or_raise_valid
(
mocker
):
def
test_is_credentials_valid_or_raise_valid
(
mocker
):
mock_query
=
MagicMock
()
mock_query
=
MagicMock
()
mock_query
.
return_value
=
None
mock_query
.
return_value
=
None
mocker
.
patch
(
'replicate.model.ModelCollection.get'
,
return_value
=
mock_query
)
mocker
.
patch
(
'replicate.model.ModelCollection.get'
,
return_value
=
mock_query
)
mocker
.
patch
(
'replicate.model.Model.versions'
,
return_value
=
mock_query
)
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
MODEL_PROVIDER_CLASS
.
is_model_credentials_valid_or_raise
(
model_name
=
'test_model_name'
,
model_name
=
'
username/
test_model_name'
,
model_type
=
ModelType
.
TEXT_GENERATION
,
model_type
=
ModelType
.
TEXT_GENERATION
,
credentials
=
VALIDATE_CREDENTIAL
.
copy
()
credentials
=
VALIDATE_CREDENTIAL
.
copy
()
)
)
...
...
api/tests/unit_tests/model_providers/test_tongyi_provider.py
View file @
b7c29ea1
...
@@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
...
@@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def
test_is_provider_credentials_valid_or_raise_valid
(
mocker
):
def
test_is_provider_credentials_valid_or_raise_valid
(
mocker
):
mocker
.
patch
(
'
langchain.llms.tongyi.
Tongyi._generate'
,
return_value
=
LLMResult
(
generations
=
[[
Generation
(
text
=
"abc"
)]]))
mocker
.
patch
(
'
core.third_party.langchain.llms.tongyi_llm.Enhance
Tongyi._generate'
,
return_value
=
LLMResult
(
generations
=
[[
Generation
(
text
=
"abc"
)]]))
MODEL_PROVIDER_CLASS
.
is_provider_credentials_valid_or_raise
(
VALIDATE_CREDENTIAL
)
MODEL_PROVIDER_CLASS
.
is_provider_credentials_valid_or_raise
(
VALIDATE_CREDENTIAL
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment