Unverified Commit b7c29ea1 authored by takatost's avatar takatost Committed by GitHub

feat: optimize model when app create (#875)

parent cc2d71c2
name: Run Pytest
on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'
- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
- name: Run pytest
run: pytest api/tests/unit_tests
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
from flask_login import login_required, current_user
......@@ -11,7 +12,9 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from libs.helper import TimestampField
......@@ -124,24 +127,34 @@ class AppListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
)
if default_model:
default_model_provider = default_model.provider_name
default_model_name = default_model.model_name
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
except Exception as e:
logging.exception(e)
default_model = None
if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']
model_config_dict["model"]["provider"] = default_model_provider
model_config_dict["model"]["name"] = default_model_name
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
......@@ -169,9 +182,21 @@ class AppListApi(Resource):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
app_model_config.model_dict["provider"]
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model_provider
model_dict['name'] = default_model_name
model_dict['provider'] = default_model.model_provider.provider_name
model_dict['name'] = default_model.name
app_model_config.model = json.dumps(model_dict)
app.name = args['name']
......
......@@ -30,8 +30,9 @@ def decrypt_side_effect(tenant_id, encrypted_key):
@patch('huggingface_hub.hf_api.ModelInfo')
def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker):
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation')
mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc")
mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation', cardData={'inference': True})
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value="abc")
mocker.patch('huggingface_hub.hf_api.HfApi.model_info', return_value=mock_model_info.return_value)
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
......
......@@ -23,14 +23,31 @@ def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def version_effect(id: str):
mock_version = MagicMock()
mock_version.openapi_schema = {
'components': {
'schemas': {
'Output': {
'items': {
'type': 'string'
}
}
}
}
}
return mock_version
@patch('replicate.version.VersionCollection.get', side_effect=version_effect)
def test_is_credentials_valid_or_raise_valid(mocker):
mock_query = MagicMock()
mock_query.return_value = None
mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
mocker.patch('replicate.model.Model.versions', return_value=mock_query)
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_name='username/test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
......
......@@ -26,7 +26,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment