Unverified Commit d069c668 authored by takatost's avatar takatost Committed by GitHub

Model Runtime (#1858)

Co-authored-by: 's avatarStyleZhang <jasonapring2015@outlook.com>
Co-authored-by: 's avatarGarfield Dai <dai.hai@foxmail.com>
Co-authored-by: 's avatarchenhe <guchenhe@gmail.com>
Co-authored-by: 's avatarjyong <jyong@dify.ai>
Co-authored-by: 's avatarJoel <iamjoel007@gmail.com>
Co-authored-by: 's avatarYeuoly <admin@srmxy.cn>
parent e91dd28a
......@@ -7,11 +7,31 @@ on:
push:
branches:
- deploy/dev
- feat/model-runtime
jobs:
test:
runs-on: ubuntu-latest
env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
CHATGLM_API_BASE: http://a.abc.com:11451
XINFERENCE_SERVER_URL: http://a.abc.com:11451
XINFERENCE_GENERATION_MODEL_UID: generate
XINFERENCE_CHAT_MODEL_UID: chat
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
XINFERENCE_RERANK_MODEL_UID: rerank
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true
steps:
- name: Checkout code
uses: actions/checkout@v2
......@@ -35,4 +55,4 @@ jobs:
pip install -r api/requirements.txt
- name: Run pytest
run: pytest api/tests/unit_tests
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
......@@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r
Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
### Provider Integrations
If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.
### i18n (Internationalization) Support
We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
......
......@@ -4,6 +4,21 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Celery",
"type": "python",
"request": "launch",
"module": "celery",
"justMyCode": true,
"args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
"envFile": "${workspaceFolder}/.env",
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"console": "integratedTerminal"
},
{
"name": "Python: Flask",
"type": "python",
......
......@@ -34,9 +34,6 @@ RUN apt-get update \
COPY --from=base /pkg /usr/local
COPY . /app/api/
RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
ENV TRANSFORMERS_OFFLINE true
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
......
......@@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
# if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import langchain
langchain.verbose = True
import time
import logging
......@@ -18,9 +21,8 @@ import threading
from flask import Flask, request, Response
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_login import login_manager
......@@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)
hosted.init_app(app)
return app
......@@ -95,6 +95,7 @@ def initialize_extensions(app):
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)
......@@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if not auth_header:
auth_token = request.args.get('_token')
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
......
......@@ -12,16 +12,12 @@ import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
......@@ -327,6 +323,8 @@ def create_qdrant_indexes():
except NotFound:
break
model_manager = ModelManager()
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
......@@ -334,19 +332,23 @@ def create_qdrant_indexes():
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
......
......@@ -87,7 +87,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.34"
self.CURRENT_VERSION = "0.4.0"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
......
......@@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source
# Import workspace controllers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
from .workspace import workspace, members, model_providers, account, tool_providers, models
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
......
......@@ -4,6 +4,10 @@ import logging
from datetime import datetime
from flask_login import current_user
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
......@@ -13,9 +17,7 @@ from controllers.console import api
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
......@@ -73,39 +75,41 @@ class AppListApi(Resource):
raise Forbidden()
try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
provider_manager = ProviderManager()
default_model_entity = provider_manager.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
default_model_entity = None
except Exception as e:
logging.exception(e)
default_model = None
default_model_entity = None
if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict,
mode=args['mode']
app_mode=args['mode']
)
app = App(
......@@ -129,21 +133,27 @@ class AppListApi(Resource):
app_model_config = AppModelConfig(**model_config_template['model_config'])
# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
app_model_config.model_dict["provider"]
)
if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.model_provider.provider_name
model_dict['name'] = default_model.name
app_model_config.model = json.dumps(model_dict)
model_manager = ModelManager()
try:
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = model_instance.provider
model_dict['name'] = model_instance.model
app_model_config.model = json.dumps(model_dict)
app.name = args['name']
app.mode = args['mode']
......
......@@ -2,6 +2,8 @@
import logging
from flask import request
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError
......@@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
......@@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......
......@@ -5,6 +5,10 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
......@@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse
......@@ -56,7 +58,7 @@ class CompletionMessageApi(Resource):
app_model=app_model,
user=account,
args=args,
from_source='console',
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming,
is_model_config_override=True
)
......@@ -75,8 +77,7 @@ class CompletionMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource):
account = flask_login.current_user
PubHandler.stop(account, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200
......@@ -132,7 +133,7 @@ class ChatMessageApi(Resource):
app_model=app_model,
user=account,
args=args,
from_source='console',
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming,
is_model_config_override=True
)
......@@ -151,8 +152,7 @@ class ChatMessageApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
......@@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource):
account = flask_login.current_user
PubHandler.stop(account, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {'result': 'success'}, 200
......
from flask_login import current_user
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from flask_restful import Resource, reqparse
......@@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
class RuleGenerateApi(Resource):
......@@ -36,8 +37,7 @@ class RuleGenerateApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
return rules
......
......@@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from fields.conversation_fields import message_detail_fields, annotation_fields
from libs.helper import uuid_value
......@@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource):
app_model = _get_app(app_id, 'completion')
try:
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
......@@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......@@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
......
......@@ -31,7 +31,7 @@ class ModelConfigResource(Resource):
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json,
mode=app.mode
app_mode=app.mode
)
new_app_model_config = AppModelConfig(
......
......@@ -4,6 +4,8 @@ from flask import request, current_app
from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
......@@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.models.entity.model_params import ModelType
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
......@@ -23,7 +24,6 @@ from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService
def _validate_name(name):
......@@ -55,16 +55,20 @@ class DatasetListApi(Resource):
current_user.current_tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
# f"in the Settings -> Model Provider.")
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
......@@ -75,6 +79,7 @@ class DatasetListApi(Resource):
item['embedding_available'] = False
else:
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
......@@ -130,13 +135,20 @@ class DatasetApi(Resource):
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
# check embedding setting
provider_service = ProviderService()
# get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
......
......@@ -2,8 +2,12 @@
from datetime import datetime
from typing import List
from flask import request, current_app
from flask import request
from flask_login import current_user
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
......@@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.indexing_runner import IndexingRunner
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
......@@ -272,10 +275,12 @@ class DatasetInitApi(Resource):
args = parser.parse_args()
if args['indexing_technique'] == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id
model_manager = ModelManager()
model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
)
except LLMBadRequestError:
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
......
......@@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
......@@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
......@@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
......@@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
......
......@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
......
......@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
......@@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......
......@@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService
......@@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
......@@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource):
if app_model.mode != 'completion':
raise NotCompletionAppError()
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
......@@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
......@@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
......@@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......
......@@ -5,7 +5,7 @@ from typing import Generator, Union
from flask import stream_with_context, Response
from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful import reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError
......@@ -13,12 +13,14 @@ import services
from controllers.console import api
from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value, TimestampField
from libs.helper import uuid_value
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
......@@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
......@@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......@@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
......
......@@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
......@@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......
......@@ -12,9 +12,10 @@ from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......@@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource):
app_model=app_model,
user=current_user,
args=args,
from_source='console',
invoke_from=InvokeFrom.EXPLORE,
streaming=True,
is_model_config_override=True,
)
......@@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource):
class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id):
PubHandler.stop(current_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {'result': 'success'}, 200
......@@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......
......@@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
......@@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
......
This diff is collapsed.
# -*- coding:utf-8 -*-
from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType
from services.provider_service import ProviderService
class ProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
"""
If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
rest is replaced by * and the last two bits are displayed in plaintext
If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
"""
provider_service = ProviderService()
provider_info_list = provider_service.get_provider_list(tenant_id)
provider_list = [
{
'provider_name': p['provider_name'],
'provider_type': p['provider_type'],
'is_valid': p['is_valid'],
'last_used': p['last_used'],
'is_enabled': p['is_valid'],
**({
'quota_type': p['quota_type'],
'quota_limit': p['quota_limit'],
'quota_used': p['quota_used']
} if p['provider_type'] == ProviderType.SYSTEM.value else {}),
'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
if p['config'] else None
}
for name, provider_info in provider_info_list.items()
for p in provider_info['providers']
]
return provider_list
class ProviderTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
provider_service = ProviderService()
try:
provider_service.save_custom_provider_config(
tenant_id=current_user.current_tenant_id,
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 201
class ProviderTokenValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument('token', required=True, nullable=False, location='json')
args = parser.parse_args()
provider_service = ProviderService()
if provider == 'openai':
args['token'] = {
'openai_api_key': args['token']
}
result = True
error = None
try:
provider_service.custom_provider_config_validate(
provider_name=provider,
config=args['token']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
endpoint='workspaces_current_providers_token') # PUT for updating provider token
api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list
......@@ -34,7 +34,6 @@ tenant_fields = {
'status': fields.String,
'created_at': TimestampField,
'role': fields.String,
'providers': fields.List(fields.Nested(provider_fields)),
'in_trial': fields.Boolean,
'trial_end_reason': fields.String,
'custom_config': fields.Raw(attribute='custom_config'),
......
......@@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
......@@ -49,8 +49,7 @@ class AudioApi(AppApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......
......@@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn
ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError
from controllers.service_api.wraps import AppApiResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......@@ -47,7 +48,7 @@ class CompletionApi(AppApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming,
)
......@@ -65,8 +66,7 @@ class CompletionApi(AppApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource):
if app_model.mode != 'completion':
raise AppUnavailableError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
......@@ -112,7 +112,7 @@ class ChatApi(AppApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.SERVICE_API,
streaming=streaming
)
......@@ -130,8 +130,7 @@ class ChatApi(AppApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
......@@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......
......@@ -4,11 +4,11 @@ import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from fields.dataset_fields import dataset_detail_fields
from services.dataset_service import DatasetService
from services.provider_service import ProviderService
def _validate_name(name):
......@@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource):
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
......
......@@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError
NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from libs.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
......
......@@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
......@@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
......@@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource):
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
......@@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource):
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
......
......@@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
......@@ -51,8 +51,7 @@ class AudioApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......
......@@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro
ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from services.completion_service import CompletionService
......@@ -44,7 +45,7 @@ class CompletionApi(WebApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
......@@ -62,8 +63,7 @@ class CompletionApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource):
if app_model.mode != 'completion':
raise NotCompletionAppError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200
......@@ -105,7 +105,7 @@ class ChatApi(WebApiResource):
app_model=app_model,
user=end_user,
args=args,
from_source='api',
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
......@@ -123,8 +123,7 @@ class ChatApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource):
if app_model.mode != 'chat':
raise NotChatAppError()
PubHandler.stop(end_user, task_id)
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {'result': 'success'}, 200
......@@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......
......@@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi
AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.web.wraps import WebApiResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.entities.application_entities import InvokeFrom
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
......@@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
response = CompletionService.generate_more_like_this(
app_model=app_model,
user=end_user,
message_id=message_id,
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming
)
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
......@@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
......@@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
......@@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
......
import logging
from typing import Optional, List
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.ai_model import AIModel
logger = logging.getLogger(__name__)
class AgentLLMCallback(Callback):
def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None:
self.agent_callback = agent_callback
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
:param llm_instance: LLM instance
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_before_invoke(
prompt_messages=prompt_messages
)
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
:param llm_instance: LLM instance
:param chunk: chunk
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
pass
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
:param llm_instance: LLM instance
:param result: result
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_after_invoke(
result=result
)
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback
:param llm_instance: LLM instance
:param ex: exception
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
"""
self.agent_callback.on_llm_error(
error=ex
)
from typing import List
from typing import List, cast
from langchain.schema import BaseMessage
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
return model_instance.get_num_tokens(to_prompt_messages(messages))
def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
:param llm:
:param model_config:
:param messages:
:return:
"""
llm_max_tokens = model_instance.model_rules.max_tokens.max
completion_max_tokens = model_instance.model_kwargs.max_tokens
used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
......
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
......@@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
......@@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
model_config: ModelConfigEntity
class Config:
"""Configuration for this pydantic object."""
......@@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
raise e
def real_plan(
self,
......@@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.content,
content=result.message.content or "",
additional_kwargs={
'function_call': result.function_call
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
......@@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
......@@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
system_message=system_message,
)
return cls(
model_instance=model_instance,
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
......
from typing import List, Tuple, Any, Union, Sequence, Optional
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
......@@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage,
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.model_manager import ModelInstance
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
model_instance: BaseLLM
summary_model_config: ModelConfigEntity = None
model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config:
"""Configuration for this pydantic object."""
......@@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
......@@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
system_message=system_message,
)
return cls(
model_instance=model_instance,
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs,
)
......@@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
:param query:
:return:
"""
original_max_tokens = self.model_instance.model_kwargs.max_tokens
self.model_instance.model_kwargs.max_tokens = 40
original_max_tokens = 0
for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
callbacks=None
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = result.function_call
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
except Exception as e:
raise e
self.model_instance.model_kwargs.max_tokens = original_max_tokens
self.model_config.parameters['max_tokens'] = original_max_tokens
return True if function_call else False
return True if result.message.tool_calls else False
def plan(
self,
......@@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.content,
content=result.message.content or "",
additional_kwargs={
'function_call': result.function_call
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
......@@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
......@@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai':
model = model_instance.base_model_name
if model_config.provider == 'azure_openai':
model = model_config.model
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_instance.base_model_name
model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken()
try:
......
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
An Multi Dataset Retrieve Agent driven by Router.
"""
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'query': kwargs['input']})
# output = ''
# rst_json = json.loads(rst)
# for item in rst_json:
# output += f'{item["content"]}\n'
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
......@@ -12,9 +12,7 @@ from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.entities.application_entities import ModelConfigEntity
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
......@@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
......@@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
raise e
try:
agent_decision = self.output_parser.parse(full_output)
......@@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
@classmethod
def create_prompt(
cls,
......@@ -182,7 +180,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
......@@ -193,7 +191,7 @@ Thought: {agent_scratchpad}
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
......@@ -207,7 +205,7 @@ Thought: {agent_scratchpad}
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
......@@ -221,7 +219,7 @@ Thought: {agent_scratchpad}
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
......@@ -238,10 +236,16 @@ Thought: {agent_scratchpad}
format_instructions=format_instructions,
input_variables=input_variables
)
llm_chain = LLMChain(
model_instance=model_instance,
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
......
......@@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage,
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
......@@ -54,7 +55,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
summary_model_config: ModelConfigEntity = None
class Config:
"""Configuration for this pydantic object."""
......@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
along with observatons
callbacks: Callbacks to run.
**kwargs: User inputs.
......@@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
raise e
try:
agent_decision = self.output_parser.parse(full_output)
......@@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_instance:
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
......@@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
......@@ -229,7 +231,7 @@ Thought: {agent_scratchpad}
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
......@@ -243,7 +245,7 @@ Thought: {agent_scratchpad}
@classmethod
def from_llm_and_tools(
cls,
model_instance: BaseLLM,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
......@@ -253,11 +255,12 @@ Thought: {agent_scratchpad}
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
......@@ -275,9 +278,15 @@ Thought: {agent_scratchpad}
input_variables=input_variables,
)
llm_chain = LLMChain(
model_instance=model_instance,
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
......
......@@ -4,10 +4,10 @@ from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
......@@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
......@@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum):
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_instance: BaseLLM
model_config: ModelConfigEntity
tools: list[BaseTool]
summary_model_instance: BaseLLM = None
memory: Optional[BaseChatMemory] = None
summary_model_config: Optional[ModelConfigEntity] = None
memory: Optional[TokenBufferMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
agent_llm_callback: Optional[AgentLLMCallback] = None
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
class Config:
......@@ -62,34 +65,42 @@ class AgentExecutor:
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True
......@@ -104,11 +115,11 @@ class AgentExecutor:
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider,
self.configuration.model_config,
query
)
if not moderation_result:
if moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
......@@ -118,7 +129,6 @@ class AgentExecutor:
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
memory=self.configuration.memory,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
......@@ -126,8 +136,8 @@ class AgentExecutor:
)
try:
output = agent_executor.run(query)
except LLMError as ex:
output = agent_executor.run(input=query)
except InvokeError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
......
import json
import logging
from typing import cast
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
logger = logging.getLogger(__name__)
class AgentApplicationRunner(AppRunner):
"""
Agent Application Runner
"""
def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
conversation: Conversation,
message: Message) -> None:
"""
Run agent application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query
)
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=None,
memory=memory
)
# Create MessageChain
message_chain = self._init_message_chain(
message=message,
query=query
)
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_config=app_orchestration_config.model_config,
message=message,
queue_manager=queue_manager,
message_chain=message_chain
)
# init LLM Callback
agent_llm_callback = AgentLLMCallback(
agent_callback=agent_callback
)
agent_runner = AgentRunnerFeature(
tenant_id=application_generate_entity.tenant_id,
app_orchestration_config=app_orchestration_config,
model_config=app_orchestration_config.model_config,
config=app_orchestration_config.agent,
queue_manager=queue_manager,
message=message,
user_id=application_generate_entity.user_id,
agent_llm_callback=agent_llm_callback,
callback=agent_callback,
memory=memory
)
# agent run
result = agent_runner.run(
query=query,
invoke_from=application_generate_entity.invoke_from
)
if result:
self._save_message_chain(
message_chain=message_chain,
output_text=result
)
if (result
and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
and app_orchestration_config.prompt_template.simple_prompt_template
):
# Direct output if agent result exists and has pre prompt
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
stream=application_generate_entity.stream,
text=result,
usage=self._get_usage_of_all_agent_thoughts(
model_config=app_orchestration_config.model_config,
message=message
)
)
else:
# As normal LLM run, agent result as context
context = result
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory
)
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
stop=stop,
stream=application_generate_entity.stream,
user=application_generate_entity.user_id,
)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
)
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == message.id).all())
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model,
model_config.credentials,
all_message_tokens,
all_answer_tokens
)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Optional, Dict
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from flask import current_app, Flask
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory
logger = logging.getLogger(__name__)
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, model_instance: BaseLLM,
conversation_message_task: ConversationMessageTask):
self.model_instance = model_instance
self.llm_message = LLMMessage()
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
real_prompts = []
for message in messages[0]:
if message.type == 'human':
role = 'user'
elif message.type == 'ai':
role = 'assistant'
else:
role = 'system'
real_prompts.append({
"role": role,
"text": message.content,
"files": [{
"type": file.type.value,
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
......@@ -279,6 +133,6 @@ class OutputModerationHandler(BaseModel):
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
logger.error("Moderation Output error: %s", e)
return None
This diff is collapsed.
import queue
import time
from typing import Generator, Any
from sqlalchemy.orm import DeclarativeMeta
from core.entities.application_entities import InvokeFrom
from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \
QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \
QueueMessageEvent, QueueMessage, AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from extensions.ext_redis import redis_client
from models.model import MessageAgentThought
class ApplicationQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
q = queue.Queue()
self._q = q
def listen(self) -> Generator:
"""
Listen to queue
:return:
"""
# wait for 10 minutes to stop listen
listen_timeout = 600
start_time = time.time()
last_ping_time = 0
while True:
try:
message = self._q.get(timeout=1)
if message is None:
break
yield message
except queue.Empty:
continue
finally:
elapsed_time = time.time() - start_time
if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
self.stop_listen()
if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent())
last_ping_time = elapsed_time // 10
def stop_listen(self) -> None:
"""
Stop listen to queue
:return:
"""
self._q.put(None)
def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
"""
Publish chunk message to channel
:param chunk: chunk
:return:
"""
self.publish(QueueMessageEvent(
chunk=chunk
))
def publish_message_replace(self, text: str) -> None:
"""
Publish message replace
:param text: text
:return:
"""
self.publish(QueueMessageReplaceEvent(
text=text
))
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
"""
Publish retriever resources
:return:
"""
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
def publish_annotation_reply(self, message_annotation_id: str) -> None:
"""
Publish annotation reply
:param message_annotation_id: message annotation id
:return:
"""
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
def publish_message_end(self, llm_result: LLMResult) -> None:
"""
Publish message end
:param llm_result: llm result
:return:
"""
self.publish(QueueMessageEndEvent(llm_result=llm_result))
self.stop_listen()
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
"""
Publish agent thought
:param message_agent_thought: message agent thought
:return:
"""
self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id
))
def publish_error(self, e) -> None:
"""
Publish error
:param e: error
:return:
"""
self.publish(QueueErrorEvent(
error=e
))
self.stop_listen()
def publish(self, event: AppQueueEvent) -> None:
"""
Publish event to queue
:param event:
:return:
"""
self._check_for_sqlalchemy_models(event.dict())
message = QueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message)
if isinstance(event, QueueStopEvent):
self.stop_listen()
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
"""
Set task stop flag
:return:
"""
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
def _is_stopped(self) -> bool:
"""
Check if task is stopped
:return:
"""
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key)
if result is not None:
redis_client.delete(stopped_cache_key)
return True
return False
@classmethod
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
"""
Generate task belong cache key
:param task_id: task id
:return:
"""
return f"generate_task_belong:{task_id}"
@classmethod
def _generate_stopped_cache_key(cls, task_id: str) -> str:
"""
Generate stopped cache key
:param task_id: task id
:return:
"""
return f"generate_task_stopped:{task_id}"
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
class ConversationTaskStoppedException(Exception):
pass
import json
import logging
from json import JSONDecodeError
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.queries = []
self.conversation_message_task = conversation_message_task
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return True
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return True
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return False
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
tool_name: str = serialized.get('name')
dataset_id = tool_name.removeprefix('dataset-')
try:
input_dict = json.loads(input_str.replace("'", "\""))
query = input_dict.get('query')
except JSONDecodeError:
query = input_str
self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.debug("Dataset tool on_llm_error: %s", error)
from pydantic import BaseModel
class ChainResult(BaseModel):
type: str = None
prompt: dict = None
completion: dict = None
status: str = 'chain_started'
completed: bool = False
started_at: float = None
completed_at: float = None
agent_result: dict = None
"""only when type is 'AgentExecutor'"""
from pydantic import BaseModel
class DatasetQueryObj(BaseModel):
dataset_id: str = None
query: str = None
from pydantic import BaseModel
class LLMMessage(BaseModel):
prompt: str = ''
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0
......@@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
try:
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
except ValueError:
thought = ''
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
print_text("\n[on_agent_action]\n" + log + "\n", color='green')
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment