Unverified Commit 8b15b742 authored by Bowen Liang's avatar Bowen Liang Committed by GitHub

generalize position helper for parsing _position.yaml and sorting objects by name (#2803)

parent 849dc056
...@@ -3,11 +3,12 @@ import importlib.util ...@@ -3,11 +3,12 @@ import importlib.util
import json import json
import logging import logging
import os import os
from collections import OrderedDict
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.utils.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum): class ExtensionModule(enum.Enum):
MODERATION = 'moderation' MODERATION = 'moderation'
...@@ -36,7 +37,8 @@ class Extensible: ...@@ -36,7 +37,8 @@ class Extensible:
@classmethod @classmethod
def scan_extensions(cls): def scan_extensions(cls):
extensions = {} extensions: list[ModuleExtension] = []
position_map = {}
# get the path of the current class # get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
...@@ -63,6 +65,7 @@ class Extensible: ...@@ -63,6 +65,7 @@ class Extensible:
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f: with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip()) position = int(f.read().strip())
position_map[extension_name] = position
if (extension_name + '.py') not in file_names: if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
...@@ -96,16 +99,15 @@ class Extensible: ...@@ -96,16 +99,15 @@ class Extensible:
with open(json_path, encoding='utf-8') as f: with open(json_path, encoding='utf-8') as f:
json_data = json.load(f) json_data = json.load(f)
extensions[extension_name] = ModuleExtension( extensions.append(ModuleExtension(
extension_class=extension_class, extension_class=extension_class,
name=extension_name, name=extension_name,
label=json_data.get('label'), label=json_data.get('label'),
form_schema=json_data.get('form_schema'), form_schema=json_data.get('form_schema'),
builtin=builtin, builtin=builtin,
position=position position=position
) ))
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position)) sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
sorted_extensions = OrderedDict(sorted_items)
return sorted_extensions return sorted_extensions
...@@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import ( ...@@ -18,6 +18,7 @@ from core.model_runtime.entities.model_entities import (
) )
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map
class AIModel(ABC): class AIModel(ABC):
...@@ -148,15 +149,7 @@ class AIModel(ABC): ...@@ -148,15 +149,7 @@ class AIModel(ABC):
] ]
# get _position.yaml file path # get _position.yaml file path
position_file_path = os.path.join(provider_model_type_path, '_position.yaml') position_map = get_position_map(provider_model_type_path)
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
# traverse all model_schema_yaml_paths # traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths: for model_schema_yaml_path in model_schema_yaml_paths:
...@@ -206,8 +199,7 @@ class AIModel(ABC): ...@@ -206,8 +199,7 @@ class AIModel(ABC):
model_schemas.append(model_schema) model_schemas.append(model_schema)
# resort model schemas by position # resort model schemas by position
if position_map: model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
# cache model schemas # cache model schemas
self.model_schemas = model_schemas self.model_schemas = model_schemas
......
import importlib import importlib
import logging import logging
import os import os
from collections import OrderedDict
from typing import Optional from typing import Optional
import yaml
from pydantic import BaseModel from pydantic import BaseModel
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
...@@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid ...@@ -12,6 +10,7 @@ from core.model_runtime.entities.provider_entities import ProviderConfig, Provid
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -200,7 +199,6 @@ class ModelProviderFactory: ...@@ -200,7 +199,6 @@ class ModelProviderFactory:
if self.model_provider_extensions: if self.model_provider_extensions:
return self.model_provider_extensions return self.model_provider_extensions
model_providers = {}
# get the path of current classes # get the path of current classes
current_path = os.path.abspath(__file__) current_path = os.path.abspath(__file__)
...@@ -215,17 +213,10 @@ class ModelProviderFactory: ...@@ -215,17 +213,10 @@ class ModelProviderFactory:
] ]
# get _position.yaml file path # get _position.yaml file path
position_file_path = os.path.join(model_providers_path, '_position.yaml') position_map = get_position_map(model_providers_path)
# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
# traverse all model_provider_dir_paths # traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []
for model_provider_dir_path in model_provider_dir_paths: for model_provider_dir_path in model_provider_dir_paths:
# get model_provider dir name # get model_provider dir name
model_provider_name = os.path.basename(model_provider_dir_path) model_provider_name = os.path.basename(model_provider_dir_path)
...@@ -256,14 +247,13 @@ class ModelProviderFactory: ...@@ -256,14 +247,13 @@ class ModelProviderFactory:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue continue
model_providers[model_provider_name] = ModelProviderExtension( model_providers.append(ModelProviderExtension(
name=model_provider_name, name=model_provider_name,
provider_instance=model_provider_class(), provider_instance=model_provider_class(),
position=position_map.get(model_provider_name) position=position_map.get(model_provider_name)
) ))
sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position)) sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
sorted_extensions = OrderedDict(sorted_items)
self.model_provider_extensions = sorted_extensions self.model_provider_extensions = sorted_extensions
......
import os.path import os.path
from yaml import FullLoader, load
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map
class BuiltinToolProviderSort: class BuiltinToolProviderSort:
...@@ -11,18 +10,14 @@ class BuiltinToolProviderSort: ...@@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
@classmethod @classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position: if not cls._position:
tmp_position = {} cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
cls._position = tmp_position
def sort_compare(provider: UserToolProvider) -> int: def name_func(provider: UserToolProvider) -> str:
if provider.type == UserToolProvider.ProviderType.MODEL: if provider.type == UserToolProvider.ProviderType.MODEL:
return cls._position.get(f'model.{provider.name}', 10000) return f'model.{provider.name}'
return cls._position.get(provider.name, 10000) else:
return provider.name
sorted_providers = sorted(providers, key=sort_compare)
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers return sorted_providers
\ No newline at end of file
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr
import yaml
def get_position_map(
        folder_path: AnyStr,
        file_name: str = '_position.yaml',
) -> dict[str, int]:
    """
    Get the mapping from name to index from a YAML file.

    The file is expected to contain a YAML list of names; the index of each
    name in that list becomes its position. Entries that are not non-blank
    strings are skipped (their index is still consumed, so the positions of
    the remaining names are unaffected).

    :param folder_path: the folder that contains the position file
    :param file_name: the YAML file name, default to '_position.yaml'
    :return: a dict with name as key and index as value; an empty dict when
        the file is missing, empty, unreadable, or not a YAML list
    """
    try:
        position_file_name = os.path.join(folder_path, file_name)
        if not os.path.exists(position_file_name):
            return {}

        with open(position_file_name, encoding='utf-8') as f:
            positions = yaml.safe_load(f)
    except Exception:
        # Best-effort loading: a broken position file must not break startup.
        # `Exception` (instead of a bare `except:`) lets KeyboardInterrupt and
        # SystemExit propagate; exc_info keeps the root cause in the log.
        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.', exc_info=True)
        return {}

    # An empty YAML document is parsed as None: nothing to order, not an error.
    if positions is None:
        return {}

    # Enumerating a mapping or a scalar string would silently yield a bogus
    # position map (keys / single characters), so reject anything but a list.
    if not isinstance(positions, list):
        logging.warning(f'Expected a YAML list in the position file {folder_path}/{file_name}, '
                        f'got {type(positions).__name__}.')
        return {}

    position_map: dict[str, int] = {}
    for index, name in enumerate(positions):
        if not isinstance(name, str):
            continue
        stripped_name = name.strip()
        # skip blank entries so that '' never becomes a key
        if stripped_name:
            position_map[stripped_name] = index
    return position_map
def sort_by_position_map(
        position_map: dict[str, int],
        data: list[Any],
        name_func: Callable[[Any], str],
) -> list[Any]:
    """
    Order objects according to the index registered for their name.

    Objects whose name has no entry in the position map are placed after all
    the positioned ones, keeping their original relative order.
    When either the position map or the data is empty, ``data`` itself is
    returned untouched.

    :param position_map: the map holding positions in the form of {name: index}
    :param data: the data to be sorted
    :param name_func: the function to get the name of the object
    :return: the sorted objects
    """
    if not (position_map and data):
        return data

    # names missing from the map rank after every real index
    unlisted_rank = float('inf')

    def rank_of(item: Any):
        return position_map.get(name_func(item), unlisted_rank)

    ordered = list(data)
    ordered.sort(key=rank_of)
    return ordered
def sort_to_dict_by_position_map(
        position_map: dict[str, int],
        data: list[Any],
        name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
    """
    Build an ordered dict of the objects, keyed by name and ordered by the position map.

    The ordering is delegated to ``sort_by_position_map``, so objects whose
    name is not in the position map end up last.

    :param position_map: the map holding positions in the form of {name: index}
    :param data: the data to be sorted
    :param name_func: the function to get the name of the object
    :return: an OrderedDict with the sorted pairs of name and object
    """
    name_to_item: OrderedDict[str, Any] = OrderedDict()
    for entry in sort_by_position_map(position_map, data, name_func):
        # insertion order of the dict mirrors the sorted order
        name_to_item[name_func(entry)] = entry
    return name_to_item
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment