Unverified Commit 3fa5204b authored by takatost's avatar takatost Committed by GitHub

feat: optimize performance (#1928)

parent 5a756ca9
...@@ -10,6 +10,7 @@ from pydantic import BaseModel ...@@ -10,6 +10,7 @@ from pydantic import BaseModel
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
...@@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel): ...@@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel):
db.session.add(provider_record) db.session.add(provider_record)
db.session.commit() db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(ProviderType.CUSTOM) self.switch_preferred_provider_type(ProviderType.CUSTOM)
def delete_custom_credentials(self) -> None: def delete_custom_credentials(self) -> None:
...@@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel): ...@@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel):
db.session.delete(provider_record) db.session.delete(provider_record)
db.session.commit() db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]: -> Optional[dict]:
""" """
...@@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel): ...@@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel):
db.session.add(provider_model_record) db.session.add(provider_model_record)
db.session.commit() db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
""" """
Delete custom model credentials. Delete custom model credentials.
...@@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel): ...@@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel):
db.session.delete(provider_model_record) db.session.delete(provider_model_record)
db.session.commit() db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def get_provider_instance(self) -> ModelProvider: def get_provider_instance(self) -> ModelProvider:
""" """
Get provider instance. Get provider instance.
......
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider"
MODEL = "provider_model"
class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return cached_provider_credentials
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 3600, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)
...@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel): ...@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory: class ModelProviderFactory:
model_provider_extensions: dict[str, ModelProviderExtension] = None model_provider_extensions: dict[str, ModelProviderExtension] = None
def __init__(self) -> None:
# for cache in memory
self.get_providers()
def get_providers(self) -> list[ProviderEntity]: def get_providers(self) -> list[ProviderEntity]:
""" """
Get all providers Get all providers
......
This diff is collapsed.
...@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple ...@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
import requests import requests
from flask import current_app from flask import current_app
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity from core.entities.model_entities import ModelStatus
from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment