Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
8b15b742
Unverified
Commit
8b15b742
authored
Mar 13, 2024
by
Bowen Liang
Committed by
GitHub
Mar 13, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
generalize position helper for parsing _position.yaml and sorting objects by name (#2803)
parent
849dc056
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
95 additions
and
46 deletions
+95
-46
extensible.py
api/core/extension/extensible.py
+8
-6
ai_model.py
api/core/model_runtime/model_providers/__base/ai_model.py
+3
-11
model_provider_factory.py
...e/model_runtime/model_providers/model_provider_factory.py
+6
-16
_positions.py
api/core/tools/provider/builtin/_positions.py
+8
-13
position_helper.py
api/core/utils/position_helper.py
+70
-0
No files found.
api/core/extension/extensible.py
View file @
8b15b742
...
...
@@ -3,11 +3,12 @@ import importlib.util
import
json
import
logging
import
os
from
collections
import
OrderedDict
from
typing
import
Any
,
Optional
from
pydantic
import
BaseModel
from
core.utils.position_helper
import
sort_to_dict_by_position_map
class
ExtensionModule
(
enum
.
Enum
):
MODERATION
=
'moderation'
...
...
@@ -36,7 +37,8 @@ class Extensible:
@
classmethod
def
scan_extensions
(
cls
):
extensions
=
{}
extensions
:
list
[
ModuleExtension
]
=
[]
position_map
=
{}
# get the path of the current class
current_path
=
os
.
path
.
abspath
(
cls
.
__module__
.
replace
(
"."
,
os
.
path
.
sep
)
+
'.py'
)
...
...
@@ -63,6 +65,7 @@ class Extensible:
if
os
.
path
.
exists
(
builtin_file_path
):
with
open
(
builtin_file_path
,
encoding
=
'utf-8'
)
as
f
:
position
=
int
(
f
.
read
()
.
strip
())
position_map
[
extension_name
]
=
position
if
(
extension_name
+
'.py'
)
not
in
file_names
:
logging
.
warning
(
f
"Missing {extension_name}.py file in {subdir_path}, Skip."
)
...
...
@@ -96,16 +99,15 @@ class Extensible:
with
open
(
json_path
,
encoding
=
'utf-8'
)
as
f
:
json_data
=
json
.
load
(
f
)
extensions
[
extension_name
]
=
ModuleExtension
(
extensions
.
append
(
ModuleExtension
(
extension_class
=
extension_class
,
name
=
extension_name
,
label
=
json_data
.
get
(
'label'
),
form_schema
=
json_data
.
get
(
'form_schema'
),
builtin
=
builtin
,
position
=
position
)
)
)
sorted_items
=
sorted
(
extensions
.
items
(),
key
=
lambda
x
:
(
x
[
1
]
.
position
is
None
,
x
[
1
]
.
position
))
sorted_extensions
=
OrderedDict
(
sorted_items
)
sorted_extensions
=
sort_to_dict_by_position_map
(
position_map
,
extensions
,
lambda
x
:
x
.
name
)
return
sorted_extensions
api/core/model_runtime/model_providers/__base/ai_model.py
View file @
8b15b742
...
...
@@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
)
from
core.model_runtime.errors.invoke
import
InvokeAuthorizationError
,
InvokeError
from
core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier
import
GPT2Tokenizer
from
core.utils.position_helper
import
get_position_map
,
sort_by_position_map
class
AIModel
(
ABC
):
...
...
@@ -148,15 +149,7 @@ class AIModel(ABC):
]
# get _position.yaml file path
position_file_path
=
os
.
path
.
join
(
provider_model_type_path
,
'_position.yaml'
)
# read _position.yaml file
position_map
=
{}
if
os
.
path
.
exists
(
position_file_path
):
with
open
(
position_file_path
,
encoding
=
'utf-8'
)
as
f
:
positions
=
yaml
.
safe_load
(
f
)
# convert list to dict with key as model provider name, value as index
position_map
=
{
position
:
index
for
index
,
position
in
enumerate
(
positions
)}
position_map
=
get_position_map
(
provider_model_type_path
)
# traverse all model_schema_yaml_paths
for
model_schema_yaml_path
in
model_schema_yaml_paths
:
...
...
@@ -206,8 +199,7 @@ class AIModel(ABC):
model_schemas
.
append
(
model_schema
)
# resort model schemas by position
if
position_map
:
model_schemas
.
sort
(
key
=
lambda
x
:
position_map
.
get
(
x
.
model
,
999
))
model_schemas
=
sort_by_position_map
(
position_map
,
model_schemas
,
lambda
x
:
x
.
model
)
# cache model schemas
self
.
model_schemas
=
model_schemas
...
...
api/core/model_runtime/model_providers/model_provider_factory.py
View file @
8b15b742
import
importlib
import
logging
import
os
from
collections
import
OrderedDict
from
typing
import
Optional
import
yaml
from
pydantic
import
BaseModel
from
core.model_runtime.entities.model_entities
import
ModelType
...
...
@@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
from
core.model_runtime.model_providers.__base.model_provider
import
ModelProvider
from
core.model_runtime.schema_validators.model_credential_schema_validator
import
ModelCredentialSchemaValidator
from
core.model_runtime.schema_validators.provider_credential_schema_validator
import
ProviderCredentialSchemaValidator
from
core.utils.position_helper
import
get_position_map
,
sort_to_dict_by_position_map
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -200,7 +199,6 @@ class ModelProviderFactory:
if
self
.
model_provider_extensions
:
return
self
.
model_provider_extensions
model_providers
=
{}
# get the path of current classes
current_path
=
os
.
path
.
abspath
(
__file__
)
...
...
@@ -215,17 +213,10 @@ class ModelProviderFactory:
]
# get _position.yaml file path
position_file_path
=
os
.
path
.
join
(
model_providers_path
,
'_position.yaml'
)
# read _position.yaml file
position_map
=
{}
if
os
.
path
.
exists
(
position_file_path
):
with
open
(
position_file_path
,
encoding
=
'utf-8'
)
as
f
:
positions
=
yaml
.
safe_load
(
f
)
# convert list to dict with key as model provider name, value as index
position_map
=
{
position
:
index
for
index
,
position
in
enumerate
(
positions
)}
position_map
=
get_position_map
(
model_providers_path
)
# traverse all model_provider_dir_paths
model_providers
:
list
[
ModelProviderExtension
]
=
[]
for
model_provider_dir_path
in
model_provider_dir_paths
:
# get model_provider dir name
model_provider_name
=
os
.
path
.
basename
(
model_provider_dir_path
)
...
...
@@ -256,14 +247,13 @@ class ModelProviderFactory:
logger
.
warning
(
f
"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip."
)
continue
model_providers
[
model_provider_name
]
=
ModelProviderExtension
(
model_providers
.
append
(
ModelProviderExtension
(
name
=
model_provider_name
,
provider_instance
=
model_provider_class
(),
position
=
position_map
.
get
(
model_provider_name
)
)
)
)
sorted_items
=
sorted
(
model_providers
.
items
(),
key
=
lambda
x
:
(
x
[
1
]
.
position
is
None
,
x
[
1
]
.
position
))
sorted_extensions
=
OrderedDict
(
sorted_items
)
sorted_extensions
=
sort_to_dict_by_position_map
(
position_map
,
model_providers
,
lambda
x
:
x
.
name
)
self
.
model_provider_extensions
=
sorted_extensions
...
...
api/core/tools/provider/builtin/_positions.py
View file @
8b15b742
import
os.path
from
yaml
import
FullLoader
,
load
from
core.tools.entities.user_entities
import
UserToolProvider
from
core.utils.position_helper
import
get_position_map
,
sort_by_position_map
class
BuiltinToolProviderSort
:
...
...
@@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
@
classmethod
def
sort
(
cls
,
providers
:
list
[
UserToolProvider
])
->
list
[
UserToolProvider
]:
if
not
cls
.
_position
:
tmp_position
=
{}
file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'_position.yaml'
)
with
open
(
file_path
)
as
f
:
for
pos
,
val
in
enumerate
(
load
(
f
,
Loader
=
FullLoader
)):
tmp_position
[
val
]
=
pos
cls
.
_position
=
tmp_position
cls
.
_position
=
get_position_map
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
))
def
sort_compare
(
provider
:
UserToolProvider
)
->
int
:
def
name_func
(
provider
:
UserToolProvider
)
->
str
:
if
provider
.
type
==
UserToolProvider
.
ProviderType
.
MODEL
:
return
cls
.
_position
.
get
(
f
'model.{provider.name}'
,
10000
)
return
cls
.
_position
.
get
(
provider
.
name
,
10000
)
sorted_providers
=
sorted
(
providers
,
key
=
sort_compare
)
return
f
'model.{provider.name}'
else
:
return
provider
.
name
sorted_providers
=
sort_by_position_map
(
cls
.
_position
,
providers
,
name_func
)
return
sorted_providers
\ No newline at end of file
api/core/utils/position_helper.py
0 → 100644
View file @
8b15b742
import
logging
import
os
from
collections
import
OrderedDict
from
collections.abc
import
Callable
from
typing
import
Any
,
AnyStr
import
yaml
def
get_position_map
(
folder_path
:
AnyStr
,
file_name
:
str
=
'_position.yaml'
,
)
->
dict
[
str
,
int
]:
"""
Get the mapping from name to index from a YAML file
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try
:
position_file_name
=
os
.
path
.
join
(
folder_path
,
file_name
)
if
not
os
.
path
.
exists
(
position_file_name
):
return
{}
with
open
(
position_file_name
,
encoding
=
'utf-8'
)
as
f
:
positions
=
yaml
.
safe_load
(
f
)
position_map
=
{}
for
index
,
name
in
enumerate
(
positions
):
if
name
and
isinstance
(
name
,
str
):
position_map
[
name
.
strip
()]
=
index
return
position_map
except
:
logging
.
warning
(
f
'Failed to load the YAML position file {folder_path}/{file_name}.'
)
return
{}
def
sort_by_position_map
(
position_map
:
dict
[
str
,
int
],
data
:
list
[
Any
],
name_func
:
Callable
[[
Any
],
str
],
)
->
list
[
Any
]:
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
:param position_map: the map holding positions in the form of {name: index}
:param name_func: the function to get the name of the object
:param data: the data to be sorted
:return: the sorted objects
"""
if
not
position_map
or
not
data
:
return
data
return
sorted
(
data
,
key
=
lambda
x
:
position_map
.
get
(
name_func
(
x
),
float
(
'inf'
)))
def
sort_to_dict_by_position_map
(
position_map
:
dict
[
str
,
int
],
data
:
list
[
Any
],
name_func
:
Callable
[[
Any
],
str
],
)
->
OrderedDict
[
str
,
Any
]:
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
:param position_map: the map holding positions in the form of {name: index}
:param name_func: the function to get the name of the object
:param data: the data to be sorted
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items
=
sort_by_position_map
(
position_map
,
data
,
name_func
)
return
OrderedDict
([(
name_func
(
item
),
item
)
for
item
in
sorted_items
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment