Unverified commit 5fa2161b authored by takatost, committed by GitHub

feat: server multi models support (#799)

parent d8b712b3
......@@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
for root, _, files in os.walk("."):
for file in files:
......
......@@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=your-client-secret
NOTION_CLIENT_ID=your-client-id
NOTION_INTERNAL_SECRET=your-internal-secret
# Hosted Model Credentials
HOSTED_OPENAI_ENABLED=false
HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
HOSTED_AZURE_OPENAI_API_BASE=
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
\ No newline at end of file
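These hosted-credential variables are consumed by the server config at startup. A minimal sketch of how one block could be resolved, assuming a plain `os.environ` lookup (illustrative only, not the repo's actual loader):

```python
import os

def env(key: str, default: str = "") -> str:
    # illustrative helper: environment value or a fallback default
    return os.environ.get(key, default)

hosted_openai = {
    "enabled": env("HOSTED_OPENAI_ENABLED", "false").lower() == "true",
    "api_key": env("HOSTED_OPENAI_API_KEY"),
    "api_base": env("HOSTED_OPENAI_API_BASE"),
    "quota_limit": int(env("HOSTED_OPENAI_QUOTA_LIMIT", "200")),
}
```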
......@@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
import flask_login
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail
ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db
from extensions.ext_login import login_manager
......@@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)
core.init_app(app)
hosted.init_app(app)
return app
......@@ -88,6 +89,7 @@ def initialize_extensions(app):
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
def _create_tenant_for_account(account):
......@@ -246,5 +248,18 @@ def threads():
}
@app.route('/db-pool-stat')
def pool_stat():
engine = db.engine
return {
'pool_size': engine.pool.size(),
'checked_in_connections': engine.pool.checkedin(),
'checked_out_connections': engine.pool.checkedout(),
'overflow_connections': engine.pool.overflow(),
'connection_timeout': engine.pool.timeout(),
'recycle_time': db.engine.pool._recycle
}
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)
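The `/db-pool-stat` route above exposes SQLAlchemy pool counters for debugging. A hedged example of inspecting it against a local dev server:

```python
import requests  # assumes the dev server started by app.run() above is reachable

stats = requests.get("http://localhost:5001/db-pool-stat").json()
print(f"{stats['checked_out_connections']}/{stats['pool_size']} connections in use")
```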
import datetime
import logging
import math
import random
import string
import time
......@@ -9,18 +9,18 @@ from flask import current_app
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from core.model_providers.providers.hosted import hosted_model_providers
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.dataset import Dataset, DatasetQuery, Document
from models.model import Account
import secrets
import base64
from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
from models.provider import Provider, ProviderType, ProviderQuotaType
@click.command('reset-password', help='Reset the account password.')
......@@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
if not hosted_model_providers.anthropic:
click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
return
click.echo(click.style('Starting to sync Anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
providers = db.session.query(Provider).filter(
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
page += 1
for tenant in tenants:
for provider in providers:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
original_quota_limit = provider.quota_limit
new_quota_limit = hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)
provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
else original_quota_limit * division
provider.quota_used = division * provider.quota_used
db.session.commit()
count += 1
except Exception as e:
click.echo(click.style(
......
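The rewritten sync command rescales existing trial quotas in place: `division = ceil(new_limit / 1000)` acts as the scale factor, a provider still at the old default limit of 1000 jumps straight to the new limit, and usage is scaled by the same factor. A toy rerun of that arithmetic (values illustrative):

```python
import math

new_quota_limit = 1_000_000  # e.g. hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)  # scale factor: 1000

original_quota_limit, original_quota_used = 1000, 600
quota_limit = new_quota_limit if original_quota_limit == 1000 \
    else original_quota_limit * division
quota_used = division * original_quota_used
print(quota_limit, quota_used)  # 1000000 600000
```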
......@@ -41,6 +41,7 @@ DEFAULTS = {
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600,
'SQLALCHEMY_ECHO': 'False',
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
......@@ -50,9 +51,16 @@ DEFAULTS = {
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30
}
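These defaults back the `get_env`/`get_bool_env` lookups used throughout this file. A minimal sketch of how such helpers presumably resolve them (assumed implementation, shown only to make the fallback order explicit):

```python
import os

DEFAULTS = {'HOSTED_OPENAI_ENABLED': 'False', 'SQLALCHEMY_POOL_RECYCLE': 3600}  # trimmed copy

def get_env(key):
    # an explicit environment variable wins; otherwise fall back to DEFAULTS
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    return str(get_env(key)).lower() == 'true'
```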
......@@ -182,7 +190,10 @@ class Config:
}
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
self.SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
}
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
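The new `pool_recycle` option caps a pooled connection's lifetime so long-idle connections are refreshed instead of failing with a dropped-socket error. An illustrative standalone SQLAlchemy equivalent of the resulting engine options (credentials hypothetical):

```python
from sqlalchemy import create_engine

engine = create_engine(
    "postgresql://user:pass@localhost:5432/dify",  # hypothetical credentials
    pool_size=30,       # SQLALCHEMY_POOL_SIZE default
    pool_recycle=3600,  # seconds; SQLALCHEMY_POOL_RECYCLE default
)
```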
......@@ -194,20 +205,35 @@ class Config:
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
# hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
# Defaults to False.
# You can disable provider config validation for compatibility with certain OpenAI-compatible providers.
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
# For temporary use only.
# Sets the default LLM provider; defaults to 'openai' and also supports `azure_openai`.
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
......
......@@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import workspace controllers
from .workspace import workspace, members, model_providers, account, tool_providers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio
# Import webhook controllers
from .webhook import stripe
......@@ -2,16 +2,17 @@
import json
from datetime import datetime
import flask
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from werkzeug.exceptions import Unauthorized, Forbidden
from werkzeug.exceptions import Forbidden
from constants.model_template import model_templates, demo_model_templates
from controllers.console import api
from controllers.console.app.error import AppNotFoundError
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from libs.helper import TimestampField
from extensions.ext_database import db
......@@ -126,9 +127,9 @@ class AppListApi(Resource):
if args['model_config'] is not None:
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=args['model_config'],
mode=args['mode']
config=args['model_config']
)
app = App(
......@@ -164,6 +165,21 @@ class AppListApi(Resource):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])
default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)
if default_model:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.provider_name
model_dict['name'] = default_model.model_name
app_model_config.model = json.dumps(model_dict)
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
app.name = args['name']
app.mode = args['mode']
app.icon = args['icon']
......
......@@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
......
......@@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse
......@@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
account = flask_login.current_user
try:
......@@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
user=account,
args=args,
from_source='console',
streaming=True,
streaming=streaming,
is_model_config_override=True
)
......@@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
account = flask_login.current_user
try:
......@@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
user=account,
args=args,
from_source='console',
streaming=True,
streaming=streaming,
is_model_config_override=True
)
......
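With the new `response_mode` argument, console completion and chat calls can opt out of streaming; anything other than "blocking" keeps the previous SSE behavior. A hypothetical request body (field names from the parsers above, values illustrative):

```python
# hypothetical console chat request body after this change
body = {
    "inputs": {},
    "query": "Hello",
    "model_config": {},           # the app's model configuration, elided here
    "conversation_id": None,      # or an existing conversation UUID
    "response_mode": "blocking",  # omit or use "streaming" for SSE as before
}
```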
......@@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
......
......@@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination
......
......@@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app_model.mode
config=request.json
)
new_app_model_config = AppModelConfig(
......
......@@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
# validate args
DocumentService.estimate_args_validate(args)
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
return response, 200
......
......@@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
......@@ -97,6 +100,15 @@ class DatasetListApi(Resource):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
......@@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
else:
raise ValueError('Data source type not supported')
return response, 200
......
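The same embedding-model guard now appears in several dataset endpoints. A hypothetical helper that would consolidate it (the name `ensure_embedding_model` is ours, not the repo's):

```python
from controllers.console.app.error import ProviderNotInitializeError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory

def ensure_embedding_model(tenant_id: str) -> None:
    # hypothetical consolidation of the guard repeated above
    try:
        ModelFactory.get_embedding_model(tenant_id=tenant_id)
    except LLMBadRequestError:
        raise ProviderNotInitializeError(
            "No Embedding Model available. Please configure a valid provider "
            "in the Settings -> Model Provider.")
```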
......@@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from libs.helper import TimestampField
from extensions.ext_database import db
......@@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
# validate args
DocumentService.document_create_args_validate(args)
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError as ex:
......@@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
# validate args
DocumentService.document_create_args_validate(args)
......@@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
return response
......@@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
elif dataset.data_source_type:
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(info_list,
data_process_rule_dict)
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
else:
raise ValueError('Data source type not supported')
return response
......
......@@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
......@@ -102,6 +102,8 @@ class HitTestingApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
......
......@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
......
......@@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......
......@@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
......
......@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import InstalledApp
......@@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters."""
app_model = installed_app.app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}
......
......@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
......
......@@ -12,9 +12,8 @@ from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.constant import llm_constant
from core.conversation_message_task import PubHandler
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......@@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('provider', type=str, required=True, location='json')
parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json')
args = parser.parse_args()
......@@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
# update app model config
args['model_config'] = app_model_config.to_dict()
args['model_config']['model']['name'] = args['model']
if not llm_constant.models[args['model']]:
raise ValueError("Model not exists.")
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
args['model_config']['model']['provider'] = args['provider']
args['model_config']['agent_mode']['tools'] = args['tools']
if not args['model_config']['agent_mode']['tools']:
......
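The universal chat API now takes the provider explicitly instead of inferring it from the removed `llm_constant.models` table. A hypothetical request body:

```python
# hypothetical universal chat request body after this change
body = {
    "query": "What's the weather like today?",
    "conversation_id": None,
    "provider": "anthropic",  # new required field
    "model": "claude-2",      # provider is no longer inferred from the model name
    "tools": [],
}
```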
......@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
......
......@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
......@@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
"""Retrieve app parameters."""
app_model = universal_app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
}
......
import logging
import stripe
from flask import request, current_app
from flask_restful import Resource
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService
class StripeWebhookApi(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400
# Handle the checkout.session.completed event
if event['type'] == 'checkout.session.completed':
logging.debug(event['data']['object']['id'])
logging.debug(event['data']['object']['amount_subtotal'])
logging.debug(event['data']['object']['currency'])
logging.debug(event['data']['object']['payment_intent'])
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])
# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()
try:
provider_checkout_service.fulfill_provider_order(event)
except Exception as e:
logging.debug(str(e))
return 'success', 200
return 'success', 200
api.add_resource(StripeWebhookApi, '/webhook/stripe')
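For local testing without the Stripe CLI, a request can be signed the way Stripe documents the `Stripe-Signature` header (scheme `v1`: HMAC-SHA256 over `"{timestamp}.{payload}"`). A sketch:

```python
import hashlib
import hmac
import json
import time

def make_stripe_signature(payload: bytes, webhook_secret: str) -> str:
    # Stripe's documented v1 scheme: HMAC-SHA256 over "{timestamp}.{payload}"
    ts = int(time.time())
    signed_payload = str(ts).encode() + b"." + payload
    digest = hmac.new(webhook_secret.encode(), signed_payload, hashlib.sha256).hexdigest()
    return f"t={ts},v1={digest}"

payload = json.dumps({"type": "checkout.session.completed",
                      "data": {"object": {"id": "cs_test"}}}).encode()
headers = {"Stripe-Signature": make_stripe_signature(payload, "whsec_test")}
# POST `payload` with `headers` to /webhook/stripe; the secret must match
# STRIPE_WEBHOOK_SECRET for construct_event() to accept the signature.
```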
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from models.provider import ProviderType
from services.provider_service import ProviderService
class DefaultModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
provider_service = ProviderService()
default_model = provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type=args['model_type']
)
if not default_model:
return None
model_provider = ModelProviderFactory.get_preferred_model_provider(
tenant_id,
default_model.provider_name
)
if not model_provider:
return {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': default_model.provider_name
}
}
provider = model_provider.provider
rst = {
'model_name': default_model.model_name,
'model_type': default_model.model_type,
'model_provider': {
'provider_name': provider.provider_name,
'provider_type': provider.provider_type
}
}
model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
if provider.provider_type == ProviderType.SYSTEM.value:
rst['model_provider']['quota_type'] = provider.quota_type
rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
rst['model_provider']['quota_limit'] = provider.quota_limit
rst['model_provider']['quota_used'] = provider.quota_used
return rst
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
provider_service.update_default_model_of_model_type(
tenant_id=current_user.current_tenant_id,
model_type=args['model_type'],
provider_name=args['provider_name'],
model_name=args['model_name']
)
return {'result': 'success'}
class ValidModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, model_type):
ModelType.value_of(model_type)
provider_service = ProviderService()
valid_models = provider_service.get_valid_model_list(
tenant_id=current_user.current_tenant_id,
model_type=model_type
)
return valid_models
api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
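Hedged usage examples for the two new workspace endpoints, assuming the console API's usual `/console/api` prefix and an authenticated session (auth handling omitted):

```python
import requests  # illustrative client calls against a local server

base = "http://localhost:5001/console/api"  # assumed console blueprint prefix

# read the workspace's current default text-generation model
r = requests.get(f"{base}/workspaces/current/default-model",
                 params={"model_type": "text-generation"})

# switch the default embedding model (field names from the parser above)
requests.post(f"{base}/workspaces/current/default-model",
              json={"model_type": "embeddings",
                    "provider_name": "openai",
                    "model_name": "text-embedding-ada-002"})
```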
# -*- coding:utf-8 -*-
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService
class ProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
rest is replaced by * and the last two bits are displayed in plaintext
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
provider_service = ProviderService()
provider_info_list = provider_service.get_provider_list(tenant_id)
provider_list = [
{
'provider_name': p['provider_name'],
'provider_type': p['provider_type'],
'is_valid': p['is_valid'],
'last_used': p['last_used'],
'is_enabled': p['is_valid'],
**({
'quota_type': p['quota_type'],
'quota_limit': p['quota_limit'],
'quota_used': p['quota_used']
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
if p['config'] else None
}
for name, provider_info in provider_info_list.items()
for p in provider_info['providers']
]
return provider_list
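A hypothetical helper matching the masking rule described in the docstring above (first six characters and last two in plaintext, the rest starred); the repo's actual masking code is not shown in this diff:

```python
def mask_token(token: str) -> str:
    # hypothetical illustration of the masking rule; not the repo's function
    if not token or len(token) <= 8:
        return '*' * len(token or '')
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

print(mask_token("sk-1234567890abcdef"))  # first 6 and last 2 chars visible
```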
class ProviderTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
provider_service = ProviderService()
try:
provider_service.save_custom_provider_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
......@@ -30,7 +30,7 @@ tenant_fields = {
'created_at': TimestampField,
'role': fields.String,
'providers': fields.List(fields.Nested(provider_fields)),
'in_trail': fields.Boolean,
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
}
......
......@@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
......@@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}
......
......@@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
......
......@@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......
......@@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
DatasetNotInitedError
from controllers.service_api.wraps import DatasetApiResource
from core.llm.error import ProviderTokenNotInitError
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import UploadFile
......
......@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
from controllers.web import api
from controllers.web.wraps import WebApiResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
......@@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}
......
......@@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
......
......@@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......
......@@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
......
import os
from typing import Optional
import langchain
from flask import Flask
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.prompt.prompt_template import OneLineFormatter
class HostedOpenAICredential(BaseModel):
api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
from typing import cast, List
from typing import List
from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage
from core.constant import llm_constant
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
class CalcTokenMixin:
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
llm = cast(ChatOpenAI, llm)
return llm.get_num_tokens_from_messages(messages)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
return model_instance.get_num_tokens(to_prompt_messages(messages))
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""
Get the remaining tokens available to the model after excluding the message tokens and the completion's max tokens.
......@@ -22,10 +19,9 @@ class CalcTokenMixin:
:param messages:
:return:
"""
llm = cast(ChatOpenAI, llm)
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
completion_max_tokens = llm.max_tokens
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
llm_max_tokens = model_instance.model_rules.max_tokens.max
completion_max_tokens = model_instance.model_kwargs.max_tokens
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
return rest_tokens
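A toy numeric pass over this budget, with stand-in values for the model rule and kwargs lookups (numbers illustrative):

```python
llm_max_tokens = 4096        # stand-in for model_instance.model_rules.max_tokens.max
completion_max_tokens = 512  # stand-in for model_instance.model_kwargs.max_tokens
used_tokens = 3700           # tokens already consumed by the prompt messages
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
print(rest_tokens)  # -116: negative, so callers summarize earlier messages
```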
......
......@@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
......@@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
A multi-dataset retrieval agent driven by a router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
......
......@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
......@@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
......
......@@ -3,20 +3,28 @@ from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
......
......@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
......@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
......
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The keywords "Thought", "Action", "Action Input" and "Final Answer" must be written in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
Return whether an agent should be used.
Using ReACT mode just to decide whether an agent is needed is itself costly,
so it's cheaper to always run the agent for reasoning directly.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.dataset_tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.dataset_tools) == 1:
tool = next(iter(self.dataset_tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
return self.output_parser.parse(full_output)
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
unique_tool_names = set(tool.name for tool in tools)
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
dataset_tools=tools,
**kwargs,
)
......@@ -14,7 +14,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The keywords "Thought", "Action", "Action Input" and "Final Answer" must be expressed in English.
......@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
......@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
......
......@@ -3,7 +3,6 @@ import logging
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
......@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
......@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
model_instance: BaseLLM
tools: list[BaseTool]
summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
summary_model_instance: BaseLLM
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
......@@ -60,36 +61,49 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_llm,
summary_llm=self.configuration.summary_model_instance.client,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.dataset_llm,
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
......
......@@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.models.llm.base import BaseLLM
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_name = model_name
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
......@@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
......@@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_name, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
......
......@@ -3,18 +3,20 @@ import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
from langchain.schema import LLMResult, BaseMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: BaseLanguageModel,
def __init__(self, model_instance: BaseLLM,
conversation_message_task: ConversationMessageTask):
self.llm = llm
self.model_instance = model_instance
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
......@@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
......@@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
......@@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
......@@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)
......@@ -5,9 +5,7 @@ from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask
......
from decimal import Decimal
models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens
'gpt-3.5-turbo-16k': 'openai', # 16,384 tokens
'text-davinci-003': 'openai', # 4,097 tokens
'text-davinci-002': 'openai', # 4,097 tokens
'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai', # 8,191 tokens, 1536 dimensions
'whisper-1': 'openai'
}
max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
'text-davinci-002': 4097,
'text-curie-001': 2049,
'text-babbage-001': 2049,
'text-ada-001': 2049,
'text-embedding-ada-002': 8191,
}
models_by_mode = {
'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
],
'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
'text-davinci-003', # 4,097 tokens
'text-davinci-002', # 4,097 tokens
'text-curie-001', # 2,049 tokens
'text-babbage-001', # 2,049 tokens
'text-ada-001' # 2,049 tokens
],
'embedding': [
'text-embedding-ada-002' # 8,191 tokens, 1536 dimensions
]
}
model_currency = 'USD'
model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': {
'prompt': Decimal('0.03'),
'completion': Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': Decimal('0.06'),
'completion': Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': Decimal('0.0015'),
'completion': Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': Decimal('0.003'),
'completion': Decimal('0.004')
},
'text-davinci-003': {
'prompt': Decimal('0.02'),
'completion': Decimal('0.02')
},
'text-curie-001': {
'prompt': Decimal('0.002'),
'completion': Decimal('0.002')
},
'text-babbage-001': {
'prompt': Decimal('0.0005'),
'completion': Decimal('0.0005')
},
'text-ada-001': {
'prompt': Decimal('0.0004'),
'completion': Decimal('0.0004')
},
'text-embedding-ada-002': {
'usage': Decimal('0.0001'),
}
}
agent_model_name = 'text-davinci-003'
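Taken together, these tables are enough to price a request. A minimal sketch (hypothetical helper; assumes, since the values above mirror the OpenAI/Anthropic rate cards, that prices are USD per 1,000 tokens):
from decimal import Decimal
def estimate_cost(model_name: str, prompt_tokens: int, completion_tokens: int) -> Decimal:
    # model_prices values are treated as USD per 1,000 tokens
    prices = model_prices[model_name]
    total = (Decimal(prompt_tokens) * prices['prompt']
             + Decimal(completion_tokens) * prices['completion']) / 1000
    return total.quantize(Decimal('0.0000001'))
# estimate_cost('gpt-3.5-turbo', 1200, 300) -> Decimal('0.0024000')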
......@@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from events.message_event import message_was_created
......@@ -16,12 +16,11 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
from models.provider import ProviderType, Provider
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.task_id = task_id
......@@ -38,9 +37,12 @@ class ConversationMessageTask:
self.conversation = conversation
self.is_new_conversation = False
self.model_instance = model_instance
self.message = None
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
self.model_name = self.model_dict.get('name')
self.mode = app.mode
......@@ -56,9 +58,6 @@ class ConversationMessageTask:
)
def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name
override_model_configs = None
if self.is_override:
override_model_configs = {
......@@ -89,15 +88,19 @@ class ConversationMessageTask:
if self.app_model_config.pre_prompt:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
model_provider_name=self.provider_name,
model_name=self.model_name
)
system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_model_config_id=self.app_model_config.id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
......@@ -117,7 +120,7 @@ class ConversationMessageTask:
self.message = Message(
app_id=self.app_model_config.app_id,
model_provider=self.model_dict.get('provider'),
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=self.conversation.id,
......@@ -131,7 +134,7 @@ class ConversationMessageTask:
answer_unit_price=0,
provider_response_latency=0,
total_price=0,
currency=llm_constant.model_currency,
currency=self.model_instance.get_currency(),
from_source=('console' if isinstance(self.user, Account) else 'api'),
from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
from_account_id=(self.user.id if isinstance(self.user, Account) else None),
......@@ -145,12 +148,10 @@ class ConversationMessageTask:
self._pub_handler.pub_text(text)
def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
model_name = self.app_model_config.model_dict.get('name')
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = llm_constant.model_prices[model_name]['prompt']
answer_unit_price = llm_constant.model_prices[model_name]['completion']
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
......@@ -163,8 +164,6 @@ class ConversationMessageTask:
self.message.provider_response_latency = llm_message.latency
self.message.total_price = total_price
self.update_provider_quota()
db.session.commit()
message_was_created.send(
......@@ -176,20 +175,6 @@ class ConversationMessageTask:
if not by_stopped:
self.end()
def update_provider_quota(self):
llm_provider_service = LLMProviderService(
tenant_id=self.app.tenant_id,
provider_name=self.message.model_provider,
)
provider = llm_provider_service.get_provider_db_record()
if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1})
def init_chain(self, chain_result: ChainResult):
message_chain = MessageChain(
message_id=self.message.id,
......@@ -229,10 +214,10 @@ class ConversationMessageTask:
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
agent_message_unit_price = agent_model_instance.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instance.get_token_price(1, MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
......@@ -253,7 +238,7 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = llm_constant.model_currency
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.flush()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
......
......@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence
from langchain.schema import Document
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
......@@ -13,12 +13,10 @@ class DatesetDocumentStore:
self,
dataset: Dataset,
user_id: str,
embedding_model_name: str,
document_id: Optional[str] = None,
):
self._dataset = dataset
self._user_id = user_id
self._embedding_model_name = embedding_model_name
self._document_id = document_id
@classmethod
......@@ -39,10 +37,6 @@ class DatesetDocumentStore:
def user_id(self) -> Any:
return self._user_id
@property
def embedding_model_name(self) -> Any:
return self._embedding_model_name
@property
def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
......@@ -74,6 +68,10 @@ class DatesetDocumentStore:
if max_position is None:
max_position = 0
embedding_model = ModelFactory.get_embedding_model(
tenant_id=self._dataset.tenant_id
)
for doc in docs:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
......@@ -88,7 +86,7 @@ class DatesetDocumentStore:
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
tokens = embedding_model.get_num_tokens(doc.page_content)
if not segment_document:
max_position += 1
......
......@@ -4,14 +4,14 @@ from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from core.model_providers.models.embedding.base import BaseEmbedding
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
def __init__(self, embeddings: BaseEmbedding):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
......@@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings):
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
if embedding_queue_texts:
try:
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
finally:
i += 1
text_embeddings.extend(embedding_results)
text_embeddings.extend(embedding_results)
return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding_results = self._embeddings.client.embed_query(text)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)
try:
embedding = Embedding(hash=hash)
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
......@@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings):
logging.exception('Failed to add embedding to db')
return embedding_results
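CacheEmbedding is a cache-aside wrapper keyed on (model_name, text hash): hits are read from the Embedding table, misses are embedded through the wrapped model's client and written back. A hypothetical usage sketch, assuming an app context and a valid tenant_id:
embedding_model = ModelFactory.get_embedding_model(tenant_id=tenant_id)  # tenant_id assumed
embeddings = CacheEmbedding(embedding_model)
# The first call computes and persists the vectors; repeated texts become cache hits.
vectors = embeddings.embed_documents(["What is Dify?", "How do datasets work?"])
query_vector = embeddings.embed_query("What is Dify?")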
import logging
from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from langchain.schema import OutputParserException
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
......@@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
# gpt-3.5-turbo works not well
generate_base_model = 'text-davinci-003'
class LLMGenerator:
@classmethod
......@@ -28,29 +22,35 @@ class LLMGenerator:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
max_tokens=50,
timeout=600
model_kwargs=ModelKwargs(
max_tokens=50
)
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model = 'gpt-3.5-turbo'
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
......@@ -68,25 +68,16 @@ class LLMGenerator:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=model,
max_tokens=max_tokens
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
......@@ -94,16 +85,13 @@ class LLMGenerator:
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
if isinstance(llm, BaseChatModel):
prompt = [HumanMessage(content=prompt)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
......@@ -119,23 +107,19 @@ class LLMGenerator:
_input = prompt.format_prompt(histories=histories)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=256
model_kwargs=ModelKwargs(
max_tokens=256,
temperature=0
)
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
prompts = [PromptMessage(content=_input.to_string())]
try:
output = llm(query)
if isinstance(output, BaseMessage):
output = output.content
questions = output_parser.parse(output)
output = model_instance.run(prompts)
questions = output_parser.parse(output.content)
except Exception:
logging.exception("Error generating suggested questions after answer")
questions = []
......@@ -160,21 +144,19 @@ class LLMGenerator:
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
llm: StreamableOpenAI = LLMBuilder.to_llm(
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_name=generate_base_model,
temperature=0,
max_tokens=512
model_kwargs=ModelKwargs(
max_tokens=512,
temperature=0
)
)
if isinstance(llm, BaseChatModel):
query = [HumanMessage(content=_input.to_string())]
else:
query = _input.to_string()
prompts = [PromptMessage(content=_input.to_string())]
try:
output = llm(query)
rule_config = output_parser.parse(output)
output = model_instance.run(prompts)
rule_config = output_parser.parse(output.content)
except OutputParserException:
raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
except Exception:
......@@ -188,25 +170,21 @@ class LLMGenerator:
return rule_config
@classmethod
async def generate_qa_document(cls, llm: StreamableOpenAI, query):
def generate_qa_document(cls, tenant_id: str, query):
prompt = GENERATOR_QA_PROMPT
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=2000
)
)
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
response = llm.generate([prompt])
answer = response.generations[0][0].text
return answer.strip()
@classmethod
def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
prompt = GENERATOR_QA_PROMPT
if isinstance(llm, BaseChatModel):
prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]
prompts = [
PromptMessage(content=prompt, type=MessageType.SYSTEM),
PromptMessage(content=query)
]
response = llm.generate([prompt])
answer = response.generations[0][0].text
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
import base64
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
def obfuscated_token(token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def encrypt_token(tenant_id: str, token: str):
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(tenant_id: str, token: str):
return rsa.decrypt(base64.b64decode(token), tenant_id)
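A round-trip sketch of these helpers (hypothetical tenant and token values; assumes the tenant row already holds an RSA key pair):
tenant_id = '00000000-0000-0000-0000-000000000000'  # hypothetical
plaintext = 'sk-abcdef1234567890'                   # hypothetical
ciphertext = encrypt_token(tenant_id, plaintext)
assert decrypt_token(tenant_id, ciphertext) == plaintext
print(obfuscated_token(plaintext))  # sk-abc***********90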
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from core.model_providers.model_factory import ModelFactory
from models.dataset import Dataset
......@@ -15,16 +14,11 @@ class IndexBuilder:
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002'
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
max_retries=1,
**model_credentials
))
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,
......
from typing import Union, Optional, List
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType, ProviderName
class LLMBuilder:
"""
This class handles the following logic:
1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below:
OPENAI_API_TYPE=azure
OPENAI_API_VERSION=2022-12-01
OPENAI_API_BASE=https://your-resource-name.openai.azure.com
OPENAI_API_KEY=<your Azure OpenAI API key>
3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config.
6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface.
7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter.
8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting.
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion':
if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI
elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
}
model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls(**model_kwargs)
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
return cls.to_llm(
tenant_id=tenant_id,
model_name=model_name,
temperature=completion_params.get('temperature', 0),
max_tokens=completion_params.get('max_tokens', 256),
top_p=completion_params.get('top_p', 0),
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callbacks=callbacks
)
@classmethod
def get_mode_by_model(cls, model_name: str) -> str:
if not model_name:
raise ValueError(f"empty model name is not supported.")
if model_name in llm_constant.models_by_mode['chat']:
return "chat"
elif model_name in llm_constant.models_by_mode['completion']:
return "completion"
else:
raise ValueError(f"model name {model_name} is not supported.")
@classmethod
def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
"""
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found.
"""
if not model_name:
raise Exception('model name not found')
#
# if model_name not in llm_constant.models:
# raise Exception('model {} not found'.format(model_name))
# model_provider = llm_constant.models[model_name]
provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
return provider_service.get_credentials(model_name)
@classmethod
def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider_name = llm_constant.models[model_name]
if provider_name == 'openai':
# get the default provider (openai / azure_openai) for the tenant
openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value:
provider = azure_openai_provider
elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = openai_provider
elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
provider_name = provider.provider_name
return provider_name
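As a sketch of how the pieces above fit together before this refactor (hypothetical tenant_id; the model dict uses the same shape as app_model_config.model_dict elsewhere in this diff):
llm = LLMBuilder.to_llm_from_model(
    tenant_id=tenant_id,  # hypothetical
    model={
        'name': 'gpt-3.5-turbo',
        'completion_params': {'temperature': 0.7, 'max_tokens': 512}
    },
    streaming=True
)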
import openai
from models.provider import ProviderName
class Moderation:
def __init__(self, provider: str, api_key: str):
self.provider = provider
self.api_key = api_key
if self.provider == ProviderName.OPENAI.value:
self.client = openai.Moderation
def moderate(self, text):
return self.client.create(input=text, api_key=self.api_key)
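A one-line usage sketch (hypothetical API key; the response shape follows the OpenAI moderation endpoint):
result = Moderation(ProviderName.OPENAI.value, 'sk-...').moderate('text to check')
flagged = result['results'][0]['flagged']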
import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return [
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'anthropic_api_key': ''
}
if obfuscated:
if not config.get('anthropic_api_key'):
config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
raise ValueError('Config must be an object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
import json
import logging
from typing import Optional, Union
import openai
import requests
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
return []
def check_embedding_model(self, credentials: Optional[dict] = None):
credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials
try:
result = openai.Embedding.create(input=['test'],
engine='text-embedding-ada-002',
timeout=60,
api_key=str(credentials.get('openai_api_key')),
api_base=str(credentials.get('openai_api_base')),
api_type='azure',
api_version=str(credentials.get('openai_api_version')))["data"][0][
"embedding"]
except openai.error.AuthenticationError as e:
raise AzureAuthenticationError(str(e))
except openai.error.APIConnectionError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. Please check your API base endpoint; the expected format is `https://xxx.openai.azure.com/`')
except openai.error.InvalidRequestError as e:
if e.http_status == 404:
raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' "
"deployment name is exists in Azure AI")
else:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
except openai.error.OpenAIError as e:
raise AzureRequestFailedError(
'Failed to request Azure OpenAI. cause: {}'.format(str(e)))
if not isinstance(result, list):
raise AzureRequestFailedError('Failed to request Azure OpenAI.')
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Azure OpenAI as a dictionary.
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 16
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):
return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
if obfuscated:
if not config.get('openai_api_key'):
config = {
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': '',
'openai_api_key': ''
}
config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
return config
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
try:
if not isinstance(config, dict):
raise ValueError('Config must be an object.')
if 'openai_api_version' not in config:
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
self.check_embedding_model(credentials=config)
except ValidateFailedError as e:
raise e
except AzureAuthenticationError:
raise ValidateFailedError('Validation failed, please check your API Key.')
except AzureRequestFailedError as ex:
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
except Exception as ex:
logging.exception('Azure OpenAI Credentials validation failed')
raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'openai_api_type': 'azure',
'openai_api_version': AZURE_OPENAI_API_VERSION,
'openai_api_base': config['openai_api_base'],
'openai_api_key': self.encrypt_token(config['openai_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
return config
class AzureAuthenticationError(Exception):
pass
class AzureRequestFailedError(Exception):
pass
import base64
from abc import ABC, abstractmethod
from typing import Optional, Union
from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.provider import Provider, ProviderType, ProviderName
class BaseProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError.
"""
provider = self.get_provider(only_custom)
if not provider:
raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if model_id and model_id == 'gpt-4':
raise ModelCurrentlyNotSupportError()
if quota_used >= quota_limit:
raise QuotaExceededError()
return self.get_hosted_credentials()
else:
return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and SYSTEM providers exist, the preferred provider is returned based on the only_custom flag.
"""
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is preferred.
"""
query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id
)
if provider_name:
query = query.filter(Provider.provider_name == provider_name)
if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
providers = query.order_by(Provider.provider_type.asc()).all()
for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
return provider
return None
def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = ''
if obfuscated:
return self.obfuscated_token(config)
return config
def obfuscated_token(self, token: str):
return token[:6] + '*' * (len(token) - 8) + token[-2:]
def get_token_type(self):
return str
def get_encrypted_token(self, config: Union[dict | str]):
return self.encrypt_token(config)
def get_decrypted_token(self, token: str):
return self.decrypt_token(token)
def encrypt_token(self, token):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token):
return rsa.decrypt(base64.b64decode(token), self.tenant_id)
@abstractmethod
def get_provider_name(self):
raise NotImplementedError
@abstractmethod
def get_credentials(self, model_id: Optional[str] = None) -> dict:
raise NotImplementedError
@abstractmethod
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
raise NotImplementedError
@abstractmethod
def config_validate(self, config: str):
raise NotImplementedError
class ValidateFailedError(Exception):
description = "Provider Validate failed"
from typing import Optional
from core.llm.provider.base import BaseProvider
from models.provider import ProviderName
class HuggingfaceProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
# todo
return []
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the API credentials for Huggingface as a dictionary for the given tenant_id.
"""
return {
'huggingface_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.HUGGINGFACEHUB
\ No newline at end of file
from typing import Optional, Union
from core.llm.provider.anthropic_provider import AnthropicProvider
from core.llm.provider.azure_provider import AzureProvider
from core.llm.provider.base import BaseProvider
from core.llm.provider.huggingface_provider import HuggingfaceProvider
from core.llm.provider.openai_provider import OpenAIProvider
from models.provider import Provider
class LLMProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self.init_provider(tenant_id, provider_name)
def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
if provider_name == 'openai':
return OpenAIProvider(tenant_id)
elif provider_name == 'azure_openai':
return AzureProvider(tenant_id)
elif provider_name == 'anthropic':
return AnthropicProvider(tenant_id)
elif provider_name == 'huggingface':
return HuggingfaceProvider(tenant_id)
else:
raise Exception('provider {} not found'.format(provider_name))
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
return self.provider.get_models(model_id)
def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
:param config:
:raises: ValidateFailedError
"""
return self.provider.config_validate(config)
def get_token_type(self):
return self.provider.get_token_type()
def get_encrypted_token(self, config: Union[dict | str]):
return self.provider.get_encrypted_token(config)
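A sketch of the service acting as a thin provider factory/facade (hypothetical values):
service = LLMProviderService(tenant_id=tenant_id, provider_name='anthropic')  # tenant_id assumed
service.config_validate({'anthropic_api_key': 'sk-ant-...'})  # raises ValidateFailedError on bad keys
encrypted = service.get_encrypted_token({'anthropic_api_key': 'sk-ant-...'})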
import logging
from typing import Optional, Union
import openai
from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName
class OpenAIProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id)
response = openai.Model.list(**credentials)
return [{
'id': model['id'],
'name': model['id'],
} for model in response['data']]
def get_credentials(self, model_id: Optional[str] = None) -> dict:
"""
Returns the credentials for the given tenant_id and provider_name.
"""
return {
'openai_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self):
return ProviderName.OPENAI
def config_validate(self, config: Union[dict, str]):
"""
Validates the given config.
"""
try:
Moderation(self.get_provider_name().value, config).moderate('test')
except (AuthenticationError, OpenAIError) as ex:
raise ValidateFailedError(str(ex))
except Exception as ex:
logging.exception('OpenAI config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str, dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key
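A hedged sketch of the validation flow above (not from the diff; tenant ID and API key are placeholders):
# hypothetical validation call; raises ValidateFailedError on bad credentials
provider = OpenAIProvider('tenant-123')
try:
    provider.config_validate('sk-...')  # placeholder API key
except ValidateFailedError as ex:
    print(f'invalid credentials: {ex}')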
from typing import List, Optional, Any, Dict
from httpx import Timeout
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
@root_validator()
def prepare_params(cls, values: Dict) -> Dict:
# mirror the Anthropic-specific params under the generic names used elsewhere
values['model_name'] = values.get('model')
values['max_tokens'] = values.get('max_tokens_to_sample')
return values
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
# translate generic param names back to the Anthropic-specific ones
params['model'] = params.pop('model_name')
params['max_tokens_to_sample'] = params.pop('max_tokens')
# Anthropic does not support these penalties, so drop them if present
params.pop('frequency_penalty', None)
params.pop('presence_penalty', None)
return params
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"{self.HUMAN_PROMPT} {message.content}"
elif isinstance(message, AIMessage):
message_text = f"{self.AI_PROMPT} {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"<admin>{message.content}</admin>"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
\ No newline at end of file
import decimal
from typing import Optional
import tiktoken
from core.constant import llm_constant
class TokenCalculator:
@classmethod
def get_num_tokens(cls, model_name: str, text: str):
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(model_name)
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
@classmethod
def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal:
if model_name in llm_constant.models_by_mode['embedding']:
unit_price = llm_constant.model_prices[model_name]['usage']
elif text_type == 'prompt':
unit_price = llm_constant.model_prices[model_name]['prompt']
elif text_type == 'completion':
unit_price = llm_constant.model_prices[model_name]['completion']
else:
raise ValueError(f'Invalid text type: {text_type}')
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
@classmethod
def get_currency(cls, model_name: str):
return llm_constant.model_currency
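A worked example of the rounding in get_token_price, assuming a hypothetical unit price of 0.002 USD per 1K prompt tokens:
import decimal

# 1,234 tokens at a hypothetical 0.002 USD per 1K tokens
tokens_per_1k = (decimal.Decimal(1234) / 1000).quantize(
    decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)  # 1.234
total_price = (tokens_per_1k * decimal.Decimal('0.002')).quantize(
    decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
print(total_price)  # 0.0024680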
import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.provider.base import BaseProvider
class Whisper:
def __init__(self, provider: BaseProvider):
self.provider = provider
# only the OpenAI provider currently supplies a speech-to-text client
if self.provider.get_provider_name() == ProviderName.OPENAI:
self.client = openai.Audio
self.credentials = provider.get_credentials()
@handle_openai_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',
file=file,
api_key=self.credentials.get('openai_api_key'),
api_base=self.credentials.get('openai_api_base'),
api_type=self.credentials.get('openai_api_type'),
api_version=self.credentials.get('openai_api_version'),
)
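A usage sketch for the wrapper above (not from the diff; the file path is a placeholder):
# hypothetical transcription; the provider comes from the service shown earlier
provider = LLMProviderService('tenant-123', 'openai').provider
whisper = Whisper(provider)
with open('speech.mp3', 'rb') as audio_file:
    text = whisper.transcribe(audio_file)  # delegates to openai.Audio.transcribe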
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper
import logging
from functools import wraps
import openai
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_openai_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
from typing import Any, List, Dict, Union
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from langchain.schema import get_buffer_string, BaseMessage
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
from models.model import Conversation, Message
......@@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
model_instance: BaseLLM
memory_key: str = "chat_history"
max_token_limit: int = 2000
message_limit: int = 10
......@@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
messages = list(reversed(messages))
chat_messages: List[BaseMessage] = []
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(HumanMessage(content=message.query))
chat_messages.append(AIMessage(content=message.answer))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:
return chat_messages
return []
# prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
if curr_buffer_length > self.max_token_limit:
pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
curr_buffer_length = self.model_instance.get_num_tokens(chat_messages)
return chat_messages
return to_lc_messages(chat_messages)
@property
def memory_variables(self) -> List[str]:
......
from typing import Type
from sqlalchemy.exc import IntegrityError
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.rules import provider_rules
from extensions.ext_database import db
from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType
DEFAULT_MODELS = {
ModelType.TEXT_GENERATION.value: {
'provider_name': 'openai',
'model_name': 'gpt-3.5-turbo',
},
ModelType.EMBEDDINGS.value: {
'provider_name': 'openai',
'model_name': 'text-embedding-ada-002',
},
ModelType.SPEECH_TO_TEXT.value: {
'provider_name': 'openai',
'model_name': 'whisper-1',
}
}
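For reference (not from the diff), the fallback table resolves like this:
default = DEFAULT_MODELS[ModelType.TEXT_GENERATION.value]
# -> {'provider_name': 'openai', 'model_name': 'gpt-3.5-turbo'}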
class ModelProviderFactory:
@classmethod
def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]:
if provider_name == 'openai':
from core.model_providers.providers.openai_provider import OpenAIProvider
return OpenAIProvider
elif provider_name == 'anthropic':
from core.model_providers.providers.anthropic_provider import AnthropicProvider
return AnthropicProvider
elif provider_name == 'minimax':
from core.model_providers.providers.minimax_provider import MinimaxProvider
return MinimaxProvider
elif provider_name == 'spark':
from core.model_providers.providers.spark_provider import SparkProvider
return SparkProvider
elif provider_name == 'tongyi':
from core.model_providers.providers.tongyi_provider import TongyiProvider
return TongyiProvider
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider
elif provider_name == 'replicate':
from core.model_providers.providers.replicate_provider import ReplicateProvider
return ReplicateProvider
elif provider_name == 'huggingface_hub':
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
return HuggingfaceHubProvider
else:
raise NotImplementedError
@classmethod
def get_provider_names(cls):
"""
Returns a list of provider names.
"""
return list(provider_rules.keys())
@classmethod
def get_provider_rules(cls):
"""
Returns all provider rules, keyed by provider name.
:return:
"""
return provider_rules
@classmethod
def get_provider_rule(cls, provider_name: str):
"""
Returns the rule for the given provider.
"""
return provider_rules[provider_name]
@classmethod
def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred model provider.
:param tenant_id: a string representing the ID of the tenant.
:param model_provider_name: name of the model provider.
:return:
"""
# get preferred provider
preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name)
if not preferred_provider or not preferred_provider.is_valid:
return None
# init model provider
model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
return model_provider_class(provider=preferred_provider)
@classmethod
def get_preferred_type_by_preferred_model_provider(cls,
tenant_id: str,
model_provider_name: str,
preferred_model_provider: TenantPreferredModelProvider):
"""
get preferred provider type by preferred model provider.
:param tenant_id:
:param model_provider_name:
:param preferred_model_provider:
:return:
"""
if not preferred_model_provider:
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
support_provider_types = model_provider_rules['support_provider_types']
if ProviderType.CUSTOM.value in support_provider_types:
custom_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.is_valid == True
).first()
if custom_provider:
return ProviderType.CUSTOM.value
model_provider = cls.get_model_provider_class(model_provider_name)
if ProviderType.SYSTEM.value in support_provider_types \
and model_provider.is_provider_type_system_supported():
return ProviderType.SYSTEM.value
elif ProviderType.CUSTOM.value in support_provider_types:
return ProviderType.CUSTOM.value
else:
return preferred_model_provider.preferred_provider_type
@classmethod
def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
# get preferred provider type
preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name)
# get providers by preferred provider type
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == preferred_provider_type
).all()
no_system_provider = False
if preferred_provider_type == ProviderType.SYSTEM.value:
quota_type_to_provider_dict = {}
for provider in providers:
quota_type_to_provider_dict[provider.quota_type] = provider
model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
for quota_type_enum in ProviderQuotaType:
quota_type = quota_type_enum.value
if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
and quota_type in quota_type_to_provider_dict.keys():
provider = quota_type_to_provider_dict[quota_type]
if provider.is_valid and provider.quota_limit > provider.quota_used:
return provider
no_system_provider = True
if no_system_provider:
providers = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).all()
if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider:
if providers:
return providers[0]
else:
try:
provider = Provider(
tenant_id=tenant_id,
provider_name=model_provider_name,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(provider)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_provider_name,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
return provider
return None
@classmethod
def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str):
"""
get preferred provider type of tenant.
:param tenant_id:
:param model_provider_name:
:return:
"""
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name == model_provider_name
).first()
return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider)
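A sketch of resolving a tenant's preferred provider through the factory above (not from the diff; the tenant ID is a placeholder):
# hypothetical resolution; returns None when no valid preferred provider exists
model_provider = ModelProviderFactory.get_preferred_model_provider('tenant-123', 'openai')
if model_provider is None:
    # tenant has neither a valid system quota nor custom credentials for openai
    raise Exception('openai is not configured for this tenant')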
from abc import ABC
from typing import Any
from core.model_providers.providers.base import BaseModelProvider
class BaseProviderModel(ABC):
_client: Any
_model_provider: BaseModelProvider
def __init__(self, model_provider: BaseModelProvider, client: Any):
self._model_provider = model_provider
self._client = client
@property
def client(self):
return self._client
@property
def model_provider(self):
return self._model_provider
import decimal
import logging
import openai
import tiktoken
from langchain.embeddings import OpenAIEmbeddings
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \
LLMAPIUnavailableError, LLMAPIConnectionError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureOpenAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
self.credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenAIEmbeddings(
deployment=name,
openai_api_type='azure',
openai_api_version=AZURE_OPENAI_API_VERSION,
chunk_size=16,
max_retries=1,
**self.credentials
)
super().__init__(model_provider, client, name)
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name'))
tokenized_text = enc.encode(text)
# calculate the number of tokens in the encoded text
return len(tokenized_text)
def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
total_price = tokens_per_1k * decimal.Decimal('0.0001')  # 0.0001 USD per 1K tokens
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to Azure OpenAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("Azure OpenAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError('Azure ' + str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError('Azure ' + str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex))
else:
return ex
from abc import abstractmethod
from typing import Any
from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseEmbedding(BaseProviderModel):
name: str
type: ModelType = ModelType.EMBEDDINGS
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def get_num_tokens(self, text: str) -> int:
"""
get num tokens of text.
:param text:
:return:
"""
if len(text) == 0:
return 0
return len(_get_token_ids_default_method(text))
def get_token_price(self, tokens: int):
return 0
def get_currency(self):
return 'USD'
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError
import decimal
import logging
from langchain.embeddings import MiniMaxEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
class MinimaxEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = MiniMaxEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
return decimal.Decimal('0')
def get_currency(self):
return 'RMB'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, ValueError):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex
import decimal
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class ReplicateEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ReplicateEmbeddings(
model=name + ':' + credentials.get('model_version'),
replicate_api_token=credentials.get('replicate_api_token')
)
super().__init__(model_provider, client, name)
def get_token_price(self, tokens: int):
# Replicate charges for prediction time only, so the token price is zero
return decimal.Decimal('0')
def get_currency(self):
return 'USD'
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}")
else:
return ex