Unverified Commit 4fdb3777 authored by John Wang, committed by GitHub

feat: universal chat in explore (#649)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
parent 94b54b7c
@@ -19,7 +19,7 @@ def check_file_for_chinese_comments(file_path):
 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
     for root, _, files in os.walk("."):
         for file in files:
...
@@ -22,7 +22,7 @@ from extensions.ext_database import db
 from extensions.ext_login import login_manager

 # DO NOT REMOVE BELOW
-from models import model, account, dataset, web, task, source
+from models import model, account, dataset, web, task, source, tool
 from events import event_handlers
 # DO NOT REMOVE ABOVE
...
@@ -18,7 +18,10 @@ from .auth import login, oauth, data_source_oauth, activate
 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

 # Import workspace controllers
-from .workspace import workspace, members, providers, account
+from .workspace import workspace, members, model_providers, account, tool_providers

 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
+
+# Import universal chat controllers
+from .universal_chat import chat, conversation, message, parameter, audio
@@ -24,6 +24,7 @@ model_config_fields = {
     'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
     'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
+    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
     'model': fields.Raw(attribute='model_dict'),
     'user_input_form': fields.Raw(attribute='user_input_form_list'),
     'pre_prompt': fields.String,
@@ -96,7 +97,8 @@ class AppListApi(Resource):
         args = parser.parse_args()

         app_models = db.paginate(
-            db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()),
+            db.select(App).where(App.tenant_id == current_user.current_tenant_id,
+                                 App.is_universal == False).order_by(App.created_at.desc()),
             page=args['page'],
             per_page=args['limit'],
             error_out=False)
@@ -147,6 +149,7 @@ class AppListApi(Resource):
             suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
             speech_to_text=json.dumps(model_configuration['speech_to_text']),
             more_like_this=json.dumps(model_configuration['more_like_this']),
+            sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
             model=json.dumps(model_configuration['model']),
             user_input_form=json.dumps(model_configuration['user_input_form']),
             pre_prompt=model_configuration['pre_prompt'],
@@ -438,6 +441,7 @@ class AppCopy(Resource):
             suggested_questions_after_answer=app_config.suggested_questions_after_answer,
             speech_to_text=app_config.speech_to_text,
             more_like_this=app_config.more_like_this,
+            sensitive_word_avoidance=app_config.sensitive_word_avoidance,
             model=app_config.model,
             user_input_form=app_config.user_input_form,
             pre_prompt=app_config.pre_prompt,
...
@@ -163,7 +163,7 @@ class CompletionConversationApi(Resource):
         if args['end']:
             end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
-            end_datetime = end_datetime.replace(second=0)
+            end_datetime = end_datetime.replace(second=59)
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
@@ -322,7 +322,7 @@ class ChatConversationApi(Resource):
         if args['end']:
             end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
-            end_datetime = end_datetime.replace(second=0)
+            end_datetime = end_datetime.replace(second=59)
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
...
@@ -43,6 +43,7 @@ class ModelConfigResource(Resource):
             suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
             speech_to_text=json.dumps(model_configuration['speech_to_text']),
             more_like_this=json.dumps(model_configuration['more_like_this']),
+            sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
             model=json.dumps(model_configuration['model']),
             user_input_form=json.dumps(model_configuration['user_input_form']),
             pre_prompt=model_configuration['pre_prompt'],
...
@@ -65,7 +65,10 @@ class ConversationApi(InstalledAppResource):
             raise NotChatAppError()

         conversation_id = str(c_id)
-        ConversationService.delete(app_model, conversation_id, current_user)
+        try:
+            ConversationService.delete(app_model, conversation_id, current_user)
+        except ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
         WebConversationService.unpin(app_model, conversation_id, current_user)

         return {"result": "success"}, 204
...
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
+from core.llm.llm_builder import LLMBuilder
+from models.provider import ProviderName
+from models.model import InstalledApp
+

 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""
@@ -27,16 +31,17 @@ class AppParameterApi(InstalledAppResource):
     }

     @marshal_with(parameters_fields)
-    def get(self, installed_app):
+    def get(self, installed_app: InstalledApp):
         """Retrieve app parameters."""
         app_model = installed_app.app
         app_model_config = app_model.app_model_config
+        provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
+            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
...
# -*- coding:utf-8 -*-
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import AppModelConfig
class UniversalChatAudioApi(UniversalChatResource):
def post(self, universal_app):
app_model = universal_app
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(UniversalChatAudioApi, '/universal-chat/audio-to-text')
\ No newline at end of file
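# --- Illustrative client call (a sketch, not part of this commit) ---
# Uploads an audio file the way request.files['file'] above expects.
# Host, API prefix and the session cookie are assumptions.
import requests
CONSOLE_API = 'http://localhost:5001/console/api'  # assumed mount point
with open('question.mp3', 'rb') as f:
    resp = requests.post(
        f'{CONSOLE_API}/universal-chat/audio-to-text',
        files={'file': f},
        cookies={'session': '...'},  # assumed console session cookie
    )
print(resp.json())  # transcription payload returned by AudioService.transcript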
import json
import logging
from typing import Generator, Union
from flask import Response, stream_with_context
from flask_login import current_user
from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.constant import llm_constant
from core.conversation_message_task import PubHandler
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
from libs.helper import uuid_value
from services.completion_service import CompletionService
class UniversalChatApi(UniversalChatResource):
def post(self, universal_app):
app_model = universal_app
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('model', type=str, required=True, location='json')
parser.add_argument('tools', type=list, required=True, location='json')
args = parser.parse_args()
app_model_config = app_model.app_model_config
# update app model config
args['model_config'] = app_model_config.to_dict()
args['model_config']['model']['name'] = args['model']
if args['model'] not in llm_constant.models:  # avoid KeyError on unknown models
raise ValueError("Model not exists.")
args['model_config']['model']['provider'] = llm_constant.models[args['model']]
args['model_config']['agent_mode']['tools'] = args['tools']
args['inputs'] = {}
del args['model']
del args['tools']
try:
response = CompletionService.completion(
app_model=app_model,
user=current_user,
args=args,
from_source='console',
streaming=True,
is_model_config_override=True,
)
return compact_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class UniversalChatStopApi(UniversalChatResource):
def post(self, universal_app, task_id):
PubHandler.stop(current_user, task_id)
return {'result': 'success'}, 200
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
try:
for chunk in response:
yield chunk
except services.errors.conversation.ConversationNotExistsError:
yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
except services.errors.conversation.ConversationCompletedError:
yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
logging.exception("internal server error.")
yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
api.add_resource(UniversalChatApi, '/universal-chat/messages')
api.add_resource(UniversalChatStopApi, '/universal-chat/messages/<string:task_id>/stop')
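# --- Illustrative SSE consumer (a sketch, not part of this commit) ---
# Shows one way to consume the 'data: {...}\n\n' chunks emitted by
# compact_response(). Host, API prefix and the session cookie are assumptions.
import json
import requests
CONSOLE_API = 'http://localhost:5001/console/api'  # assumed mount point
resp = requests.post(
    f'{CONSOLE_API}/universal-chat/messages',
    json={
        'query': 'What time is it in UTC?',
        'model': 'gpt-3.5-turbo-16k',  # must exist in llm_constant.models
        'tools': [],                   # written into agent_mode['tools']
        # 'conversation_id': '...'     # optional: continue a conversation
    },
    cookies={'session': '...'},        # assumed console session cookie
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith('data: '):
        print(json.loads(line[len('data: '):]))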
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import fields, reqparse, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField,
'model_config': fields.Raw,
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class UniversalChatConversationListApi(UniversalChatResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
args = parser.parse_args()
pinned = None
if 'pinned' in args and args['pinned'] is not None:
pinned = args['pinned'] == 'true'
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
pinned=pinned
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class UniversalChatConversationApi(UniversalChatResource):
def delete(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
try:
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204
class UniversalChatConversationRenameApi(UniversalChatResource):
@marshal_with(conversation_fields)
def post(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class UniversalChatConversationPinApi(UniversalChatResource):
def patch(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
try:
WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
class UniversalChatConversationUnPinApi(UniversalChatResource):
def patch(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}
api.add_resource(UniversalChatConversationRenameApi, '/universal-chat/conversations/<uuid:c_id>/name')
api.add_resource(UniversalChatConversationListApi, '/universal-chat/conversations')
api.add_resource(UniversalChatConversationApi, '/universal-chat/conversations/<uuid:c_id>')
api.add_resource(UniversalChatConversationPinApi, '/universal-chat/conversations/<uuid:c_id>/pin')
api.add_resource(UniversalChatConversationUnPinApi, '/universal-chat/conversations/<uuid:c_id>/unpin')
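# Illustrative requests against the routes above (console session assumed):
#   GET    /universal-chat/conversations?limit=20&pinned=true      -> infinite-scroll page
#   POST   /universal-chat/conversations/<uuid>/name {"name": "…"} -> renamed conversation
#   PATCH  /universal-chat/conversations/<uuid>/pin                -> {"result": "success"}
#   DELETE /universal-chat/conversations/<uuid>                    -> 204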
# -*- coding:utf-8 -*-
import logging
from flask_login import current_user
from flask_restful import reqparse, fields, marshal_with
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound, InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
class UniversalChatMessageListApi(UniversalChatResource):
feedback_fields = {
'rating': fields.String
}
agent_thought_fields = {
'id': fields.String,
'chain_id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_input': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField,
'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app
parser = reqparse.RequestParser()
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args()
try:
return MessageService.pagination_by_first_id(app_model, current_user,
args['conversation_id'], args['first_id'], args['limit'])
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.")
class UniversalChatMessageFeedbackApi(UniversalChatResource):
def post(self, universal_app, message_id):
app_model = universal_app
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success'}
class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
def get(self, universal_app, message_id):
app_model = universal_app
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
return {'data': questions}
api.add_resource(UniversalChatMessageListApi, '/universal-chat/messages')
api.add_resource(UniversalChatMessageFeedbackApi, '/universal-chat/messages/<uuid:message_id>/feedbacks')
api.add_resource(UniversalChatMessageSuggestedQuestionApi, '/universal-chat/messages/<uuid:message_id>/suggested-questions')
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from core.llm.llm_builder import LLMBuilder
from models.provider import ProviderName
from models.model import App
class UniversalChatParameterApi(UniversalChatResource):
"""Resource for app variables."""
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
}
@marshal_with(parameters_fields)
def get(self, universal_app: App):
"""Retrieve app parameters."""
app_model = universal_app
app_model_config = app_model.app_model_config
provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
return {
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
}
api.add_resource(UniversalChatParameterApi, '/universal-chat/parameters')
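# Note: as in the explore, web and service_api variants of this resource,
# speech_to_text is reported as enabled only when the tenant's default provider
# for 'whisper-1' resolves to OpenAI; every other provider falls back to
# {'enabled': False} regardless of what is stored in app_model_config.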
import json
from functools import wraps
from flask_login import login_required, current_user
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from models.model import App, AppModelConfig
def universal_chat_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
# get universal chat app
universal_app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.is_universal == True
).first()
if universal_app is None:
# create universal app if not exists
universal_app = App(
tenant_id=current_user.current_tenant_id,
name='Universal Chat',
mode='chat',
is_universal=True,
icon='',
icon_background='',
api_rpm=0,
api_rph=0,
enable_site=False,
enable_api=False,
status='normal'
)
db.session.add(universal_app)
db.session.flush()
app_model_config = AppModelConfig(
provider="",
model_id="",
configs={},
opening_statement='',
suggested_questions=json.dumps([]),
suggested_questions_after_answer=json.dumps({'enabled': True}),
speech_to_text=json.dumps({'enabled': True}),
more_like_this=None,
sensitive_word_avoidance=None,
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-16k",
"completion_params": {
"max_tokens": 800,
"temperature": 0.8,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
}),
user_input_form=json.dumps([]),
pre_prompt='',
agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
)
app_model_config.app_id = universal_app.id
db.session.add(app_model_config)
db.session.flush()
universal_app.app_model_config_id = app_model_config.id
db.session.commit()
return view(universal_app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
class UniversalChatResource(Resource):
# NOTE: flask-restful applies method_decorators in reverse of the usual
# @-stacking order, so setup_required (listed last) runs first.
method_decorators = [universal_chat_app_required, account_initialization_required, login_required, setup_required]
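# A hypothetical subclass (a sketch, not part of this commit) showing how the
# decorator chain above is consumed: every handler on a UniversalChatResource
# receives the tenant's lazily created universal app as its first argument.
class UniversalChatPingApi(UniversalChatResource):
    def get(self, universal_app: App):
        # universal_app is injected by universal_chat_app_required
        return {'app_id': str(universal_app.id), 'mode': universal_app.mode}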
import json
from flask_login import login_required, current_user
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.provider.tool_provider_service import ToolProviderService
from extensions.ext_database import db
from models.tool import ToolProvider, ToolProviderName
class ToolProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
tool_credential_dict = {}
for tool_name in ToolProviderName:
tool_credential_dict[tool_name.value] = {
'tool_name': tool_name.value,
'is_enabled': False,
'credentials': None
}
tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
for p in tool_providers:
if p.is_enabled:
tool_credential_dict[p.tool_name] = {
'tool_name': p.tool_name,
'is_enabled': p.is_enabled,
'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
}
return list(tool_credential_dict.values())
class ToolProviderCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
f'current_role is {current_user.current_tenant.current_role}')
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
tool_provider_service = ToolProviderService(tenant_id, provider)
try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
raise ValueError(str(ex))
encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
tenant = current_user.current_tenant
tool_provider_model = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == tenant.id,
ToolProvider.tool_name == provider,
).first()
# Only allow updating token for CUSTOM provider type
if tool_provider_model:
tool_provider_model.encrypted_credentials = encrypted_credentials
tool_provider_model.is_enabled = True
else:
tool_provider_model = ToolProvider(
tenant_id=tenant.id,
tool_name=provider,
encrypted_credentials=encrypted_credentials,
is_enabled=True
)
db.session.add(tool_provider_model)
db.session.commit()
return {'result': 'success'}, 201
class ToolProviderCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
if provider not in [p.value for p in ToolProviderName]:
abort(404)
parser = reqparse.RequestParser()
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
result = True
error = None
tenant_id = current_user.current_tenant_id
tool_provider_service = ToolProviderService(tenant_id, provider)
try:
tool_provider_service.credentials_validate(args['credentials'])
except ToolValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
api.add_resource(ToolProviderCredentialsValidateApi,
'/workspaces/current/tool-providers/<provider>/credentials-validate')
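# --- Illustrative credential check (a sketch, not part of this commit) ---
# 'serpapi' and the credential keys are placeholders; valid names come from
# ToolProviderName. Host, API prefix and the session cookie are assumptions.
import requests
CONSOLE_API = 'http://localhost:5001/console/api'  # assumed mount point
resp = requests.post(
    f'{CONSOLE_API}/workspaces/current/tool-providers/serpapi/credentials-validate',
    json={'credentials': {'api_key': 'your-key-here'}},
    cookies={'session': '...'},  # assumed console session cookie
)
print(resp.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}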
@@ -4,6 +4,10 @@ from flask_restful import fields, marshal_with
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
+from core.llm.llm_builder import LLMBuilder
+from models.provider import ProviderName
+from models.model import App
+

 class AppParameterApi(AppApiResource):
     """Resource for app variables."""
@@ -28,15 +32,16 @@ class AppParameterApi(AppApiResource):
     }

     @marshal_with(parameters_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
+        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
+            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
...
@@ -4,6 +4,10 @@ from flask_restful import marshal_with, fields
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
+from core.llm.llm_builder import LLMBuilder
+from models.provider import ProviderName
+from models.model import App
+

 class AppParameterApi(WebApiResource):
     """Resource for app variables."""
@@ -27,15 +31,16 @@ class AppParameterApi(WebApiResource):
     }

     @marshal_with(parameters_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
+        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict,
+            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
...
@@ -62,7 +62,10 @@ class ConversationApi(WebApiResource):
             raise NotChatAppError()

         conversation_id = str(c_id)
-        ConversationService.delete(app_model, conversation_id, end_user)
+        try:
+            ConversationService.delete(app_model, conversation_id, end_user)
+        except ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
         WebConversationService.unpin(app_model, conversation_id, end_user)

         return {"result": "success"}, 204
...
from typing import cast, List
from langchain import OpenAI
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage
from core.constant import llm_constant
class CalcTokenMixin:
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
llm = cast(ChatOpenAI, llm)
return llm.get_num_tokens_from_messages(messages)
def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
"""
Get the remaining tokens available to the model after subtracting the message tokens and the reserved completion max tokens.
:param llm:
:param messages:
:return:
"""
llm = cast(ChatOpenAI, llm)
llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
completion_max_tokens = llm.max_tokens
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass
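# Worked example of the budget above (illustrative numbers): for
# gpt-3.5-turbo-16k, llm_constant.max_context_token_length gives 16384; with
# max_tokens=800 reserved for the completion and a prompt currently measuring
# 1200 tokens, get_message_rest_tokens returns 16384 - 800 - 1200 = 14384.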
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage
from langchain.tools import BaseTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
"""
A multi-dataset retrieval agent driven by a router.
"""
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
if len(self.tools) == 0:
return AgentFinish(return_values={"output": ''}, log='')
elif len(self.tools) == 1:
tool = next(iter(self.tools))
tool = cast(DatasetRetrieverTool, tool)
rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
return super().plan(intermediate_steps, callbacks, **kwargs)
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
raise NotImplementedError()
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
llm.model_name = 'gpt-3.5-turbo'
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
**kwargs,
)
from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional
import pytz
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
_format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
current_time = datetime.now()
current_timezone = pytz.timezone('UTC')
current_time = current_timezone.localize(current_time)
return SystemMessage(content="You are a helpful AI assistant.\n"
"Current time: {}\n"
"Respond directly if appropriate.".format(
current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
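# Keep the trailing (AI message, function observation) pair verbatim and fold
# everything before it into the moving summary; if nothing besides that pair
# remains, the prompt cannot be shrunk further and we give up.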
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
summary_handler = SummarizerMixin(llm=self.summary_llm)
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, llm)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
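# Illustrative accounting for the rules above (gpt-3.5-turbo): a single user
# message {'role': 'user', 'content': 'hello'} costs tokens_per_message (4)
# plus the encoded lengths of its 'role' and 'content' values, and every
# request adds the final 3-token assistant primer; function schemas are charged
# per encoded name/description/parameter key exactly as in the loop above.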
from datetime import datetime
from typing import List, Tuple, Any, Union, Sequence, Optional
import pytz
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseMultiActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
current_time = datetime.now()
current_timezone = pytz.timezone('UTC')
current_time = current_timezone.localize(current_time)
return SystemMessage(content="You are a helpful AI assistant.\n"
"Current time: {}\n"
"Respond directly if appropriate.".format(
current_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")))
import json
import re
from typing import Union
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException
class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
try:
action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
# gpt turbo frequently ignores the directive to emit a single action
logger.warning("Got multiple action responses: %s", response)
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e
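# Illustrative behavior of the parser above: a fenced blob such as
#   ```
#   {"action": "web_search", "action_input": {"query": "dify"}}
#   ```
# yields AgentAction('web_search', {'query': 'dify'}, text); a fenced JSON list
# is reduced to its first element with a warning; unfenced text falls through
# to AgentFinish({'output': text}, text). 'web_search' is a placeholder name.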
import re
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel
def should_use_agent(self, query: str):
"""
Return whether the agent should be used for this query.
Probing first (an extra ReACT round-trip just to decide) would be costly,
so it is cheaper to always reason through the agent directly.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
self.moving_summary_index = len(intermediate_steps)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
if 'chat_history' in kwargs:
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
**kwargs,
)
from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.llm.llm_builder import LLMBuilder
class AgentBuilder:
@classmethod
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
for tool in tools:
tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
]
prompt = cls.build_agent_prompt_template(
tools=tools,
memory=memory,
)
agent_llm_chain = LLMChain(
llm=llm,
prompt=prompt,
)
agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
agent_callback_manager = CallbackManager(
[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
agent_chain = AgentExecutor.from_agent_and_tools(
tools=tools,
agent=agent,
memory=memory,
callbacks=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
            # `generate` makes the agent produce a final answer once the iteration limit or request time limit is reached
)
return agent_chain
@classmethod
def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
if memory:
prompt = ConversationalAgent.create_prompt(
tools=tools,
)
else:
prompt = ZeroShotAgent.create_prompt(
tools=tools,
)
return prompt
@classmethod
def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
if memory:
agent = ConversationalAgent(
llm_chain=agent_llm_chain
)
else:
agent = ZeroShotAgent(
llm_chain=agent_llm_chain
)
return agent
import enum
import logging
from typing import Union, Optional
from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
llm: BaseLanguageModel
tools: list[BaseTool]
summary_llm: BaseLanguageModel
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
max_execution_time: Optional[float] = None
early_stopping_method: str = "generate"
    # `generate` makes the agent produce a final answer once the iteration limit or request time limit is reached
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
class AgentExecuteResult(BaseModel):
strategy: PlanningStrategy
output: Optional[str]
configuration: AgentConfiguration
class AgentExecutor:
def __init__(self, configuration: AgentConfiguration):
self.configuration = configuration
self.agent = self._init_agent()
    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to pass chat history into the prompt
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
                extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used to pass chat history into the prompt
summary_llm=self.configuration.summary_llm,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
)
else:
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
return agent
def should_use_agent(self, query: str) -> bool:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
memory=self.configuration.memory,
max_iterations=self.configuration.max_iterations,
max_execution_time=self.configuration.max_execution_time,
early_stopping_method=self.configuration.early_stopping_method,
callbacks=self.configuration.callbacks
)
try:
output = agent_executor.run(query)
except Exception:
logging.exception("agent_executor run failed")
output = None
return AgentExecuteResult(
output=output,
strategy=self.configuration.strategy,
configuration=self.configuration
)
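A minimal sketch of driving the executor, assuming placeholder models and an empty tool list (none of this is in the diff):
from langchain.chat_models import ChatOpenAI
config = AgentConfiguration(
    strategy=PlanningStrategy.REACT,
    llm=ChatOpenAI(temperature=0),
    summary_llm=ChatOpenAI(temperature=0),
    tools=[],  # e.g. DatasetRetrieverTool instances
    memory=None,
)
executor = AgentExecutor(config)
if executor.should_use_agent("What is in my dataset?"):
    result = executor.run("What is in my dataset?")
    print(result.strategy, result.output)  # output is None if the run failed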
+import json
 import logging
 import time
 from typing import Any, Dict, List, Union, Optional
+from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
...@@ -20,6 +22,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
+        self._message_agent_thought = None
         self.current_chain = None
     @property
...@@ -29,6 +32,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     def clear_agent_loops(self) -> None:
         self._agent_loops = []
         self._current_loop = None
+        self._message_agent_thought = None
     @property
     def always_verbose(self) -> bool:
...@@ -61,8 +65,20 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         # kwargs={}
         if self._current_loop and self._current_loop.status == 'llm_started':
             self._current_loop.status = 'llm_end'
+            if response.llm_output:
-            self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+                self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+            completion_generation = response.generations[0][0]
+            if isinstance(completion_generation, ChatGeneration):
+                completion_message = completion_generation.message
+                if 'function_call' in completion_message.additional_kwargs:
+                    self._current_loop.completion \
+                        = json.dumps({'function_call': completion_message.additional_kwargs['function_call']})
+                else:
-            self._current_loop.completion = response.generations[0][0].text
+                    self._current_loop.completion = response.generations[0][0].text
+            else:
+                self._current_loop.completion = completion_generation.text
+            if response.llm_output:
-            self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+                self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
     def on_llm_error(
...@@ -71,6 +87,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         logging.error(error)
         self._agent_loops = []
         self._current_loop = None
+        self._message_agent_thought = None
     def on_tool_start(
         self,
...@@ -89,8 +106,15 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     ) -> Any:
         """Run on agent action."""
         tool = action.tool
-        tool_input = action.tool_input
-        action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
+        tool_input = json.dumps({"query": action.tool_input}
+                                if isinstance(action.tool_input, str) else action.tool_input)
+        completion = None
+        if isinstance(action, openai_functions_agent.base._FunctionsAgentAction) \
+                or isinstance(action, openai_functions_multi_agent.base._FunctionsAgentAction):
+            thought = action.log.strip()
+            completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
+        else:
+            action_name_position = action.log.index("Action:") if action.log else -1
-        thought = action.log[:action_name_position].strip() if action.log else ''
+            thought = action.log[:action_name_position].strip() if action.log else ''
         if self._current_loop and self._current_loop.status == 'llm_end':
...@@ -98,6 +122,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.thought = thought
             self._current_loop.tool_name = tool
             self._current_loop.tool_input = tool_input
+            if completion is not None:
+                self._current_loop.completion = completion
+            self._message_agent_thought = self.conversation_message_task.on_agent_start(
+                self.current_chain,
+                self._current_loop
+            )
     def on_tool_end(
         self,
...@@ -120,10 +151,13 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completed_at = time.perf_counter()
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
-            self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
+            self.conversation_message_task.on_agent_end(
+                self._message_agent_thought, self.model_name, self._current_loop
+            )
             self._agent_loops.append(self._current_loop)
             self._current_loop = None
+            self._message_agent_thought = None
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
...@@ -132,6 +166,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         logging.error(error)
         self._agent_loops = []
         self._current_loop = None
+        self._message_agent_thought = None
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
...@@ -141,10 +176,18 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completed = True
             self._current_loop.completed_at = time.perf_counter()
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
+            self._current_loop.thought = '[DONE]'
+            self._message_agent_thought = self.conversation_message_task.on_agent_start(
+                self.current_chain,
+                self._current_loop
+            )
-            self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop)
+            self.conversation_message_task.on_agent_end(
+                self._message_agent_thought, self.model_name, self._current_loop
+            )
             self._agent_loops.append(self._current_loop)
             self._current_loop = None
+            self._message_agent_thought = None
         elif not self._current_loop and self._agent_loops:
             self._agent_loops[-1].status = 'agent_finish'
+import json
 import logging
 from typing import Any, Dict, List, Union, Optional
...@@ -43,9 +44,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         input_str: str,
         **kwargs: Any,
     ) -> None:
-        tool_name = serialized.get('name')
-        dataset_id = tool_name[len("dataset-"):]
-        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str))
+        # tool_name = serialized.get('name')
+        input_dict = json.loads(input_str.replace("'", "\""))
+        dataset_id = input_dict.get('dataset_id')
+        query = input_dict.get('query')
+        self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))
     def on_tool_end(
         self,
...
...@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
     tool_output: str = None
     prompt: str = None
-    prompt_tokens: int = None
+    prompt_tokens: int = 0
     completion: str = None
-    completion_tokens: int = None
+    completion_tokens: int = 0
     latency: float = None
...
 import logging
 import time
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List, Union
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
 class LLMCallbackHandler(BaseCallbackHandler):
     raise_error: bool = True
-    def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def __init__(self, llm: BaseLanguageModel,
                  conversation_message_task: ConversationMessageTask):
         self.llm = llm
         self.llm_message = LLMMessage()
...
...@@ -20,15 +20,13 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
         self._current_chain_result = None
         self._current_chain_message = None
         self.conversation_message_task = conversation_message_task
-        self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler(
-            llm_constant.agent_model_name,
-            conversation_message_task
-        )
+        self.agent_callback = None
     def clear_chain_results(self) -> None:
         self._current_chain_result = None
         self._current_chain_message = None
-        self.agent_loop_gather_callback_handler.current_chain = None
+        if self.agent_callback:
+            self.agent_callback.current_chain = None
     @property
     def always_verbose(self) -> bool:
...@@ -58,7 +56,8 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
                 started_at=time.perf_counter()
             )
             self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            if self.agent_callback:
+                self.agent_callback.current_chain = self._current_chain_message
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
...
from typing import Optional
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
class ChainBuilder:
@classmethod
def to_tool_chain(cls, tool, **kwargs) -> ToolChain:
return ToolChain(
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
    def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[SensitiveWordAvoidanceChain]:
sensitive_words = tool_config.get("words", "")
if tool_config.get("enabled", False) \
and sensitive_words:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)
return None
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
class Route(NamedTuple):
destination: Optional[str]
next_inputs: Dict[str, Any]
class LLMRouterChain(Chain):
"""A router chain that uses an LLM chain to perform routing."""
llm_chain: LLMChain
"""LLM chain used to perform routing"""
@root_validator()
def validate_prompt(cls, values: dict) -> dict:
prompt = values["llm_chain"].prompt
if prompt.output_parser is None:
raise ValueError(
"LLMRouterChain requires base llm_chain prompt to have an output"
" parser that converts LLM text output to a dictionary with keys"
" 'destination' and 'next_inputs'. Received a prompt with no output"
" parser."
)
return values
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the LLM chain prompt expects.
:meta private:
"""
return self.llm_chain.input_keys
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict):
raise ValueError
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],
self.llm_chain.predict_and_parse(**inputs),
)
return output
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
) -> LLMRouterChain:
"""Convenience constructor."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
@property
def output_keys(self) -> List[str]:
return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any]) -> Route:
result = self(inputs)
return Route(result["destination"], result["next_inputs"])
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
"""Parser for output of router chain int he multi-prompt chain."""
default_destination: str = "DEFAULT"
next_inputs_type: Type = str
next_inputs_inner_key: str = "input"
def parse(self, text: str) -> Dict[str, Any]:
try:
expected_keys = ["destination", "next_inputs"]
parsed = parse_and_check_json_markdown(text, expected_keys)
if not isinstance(parsed["destination"], str):
raise ValueError("Expected 'destination' to be a string.")
if not isinstance(parsed["next_inputs"], self.next_inputs_type):
raise ValueError(
f"Expected 'next_inputs' to be {self.next_inputs_type}."
)
parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
if (
parsed["destination"].strip().lower()
== self.default_destination.lower()
):
parsed["destination"] = None
else:
parsed["destination"] = parsed["destination"].strip()
return parsed
except Exception as e:
raise OutputParserException(
f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
)
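For illustration, the markdown blob the router LLM is asked to emit and what RouterOutputParser makes of it (sample values invented, not from the diff):
sample = '''```json
{
    "destination": "DEFAULT",
    "next_inputs": "What is universal chat?"
}
```'''
RouterOutputParser().parse(sample)
# -> {'destination': None, 'next_inputs': {'input': 'What is universal chat?'}}
# 'DEFAULT' is normalized to None, which lets callers fall through gracefully.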
from typing import Optional, List, cast
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
chains = []
chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
# agent mode
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
chains += tool_chains
if chains_output_key:
final_output_key = chains_output_key
if len(chains) == 0:
return None
for chain in chains:
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
chains=chains,
input_variables=[first_input_key],
output_variables=[final_output_key],
            memory=memory,  # used only to supply the memory prompt input key
)
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
if agent_mode and agent_mode.get('enabled'):
tools = agent_mode.get('tools', [])
pre_fixed_chains = []
# agent_tools = []
datasets = []
for tool in tools:
tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
if tool_type == 'sensitive-word-avoidance':
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
if chain:
pre_fixed_chains.append(chain)
elif tool_type == "dataset":
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == tool_config.get("id")
).first()
if dataset:
datasets.append(dataset)
# add pre-fixed chains
chains += pre_fixed_chains
if len(datasets) > 0:
# tool to chain
multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
final_output_key = cls.get_chains_output_key(chains)
return chains, final_output_key
@classmethod
def get_chains_output_key(cls, chains: List[Chain]):
if len(chains) > 0:
return chains[-1].output_keys[0]
return None
import math
import re
from typing import Mapping, List, Dict, Any, Optional
from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like the \
example below, with no other text outside the markdown code snippet:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
"""Use a single chain to route an input to one of multiple candidate chains."""
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools: Mapping[str, DatasetTool]
"""Map of name to candidate chains that inputs can be routed to."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the router chain prompt expects.
:meta private:
"""
return self.router_chain.input_keys
@property
def output_keys(self) -> List[str]:
return ["text"]
@classmethod
def from_datasets(
cls,
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
dataset_tools = {}
for dataset in datasets:
            # skip datasets that have no available documents
            if dataset.available_document_count == 0:
                continue
            # fall back to a default description when it is empty
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=k,
dataset=dataset,
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
dataset_tools[str(dataset.id)] = dataset_tool
return cls(
router_chain=router_chain,
dataset_tools=dataset_tools,
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if len(self.dataset_tools) == 0:
return {"text": ''}
elif len(self.dataset_tools) == 1:
return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
route = self.router_chain.route(inputs)
destination = ''
if route.destination:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, route.destination, re.IGNORECASE)
if match:
destination = match.group()
if not destination:
return {"text": ''}
elif destination in self.dataset_tools:
return {"text": self.dataset_tools[destination].run(
route.next_inputs['input']
)}
else:
raise ValueError(
f"Received invalid destination chain name '{destination}'"
)
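To make _dynamic_calc_retrieve_k concrete, a worked example under assumed numbers (not taken from the diff):
# assume segment_max_tokens = 500, DEFAULT_K = 2, CONTEXT_TOKENS_PERCENT = 0.3
# rest_tokens = 800:   800 < 500 * 2, so k = 800 // 500 = 1
# rest_tokens = 4000:  context_limit = floor(4000 * 0.3) = 1200 > 1000, so k = 1200 // 500 = 2
# rest_tokens = 10000: context_limit = 3000, so k = 3000 // 500 = 6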
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
class ToolChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
tool: BaseTool
@property
def _chain_type(self) -> str:
return "tool_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)
return {self.output_key: output}
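A minimal usage sketch for ToolChain, with an invented echo tool (not part of the diff):
from langchain.tools import Tool
echo_tool = Tool(
    name="echo",
    description="echoes the input back",
    func=lambda s: f"echo: {s}",
)
chain = ToolChain(tool=echo_tool, input_key="input", output_key="tool_output")
chain.run("hello")  # -> 'echo: hello'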
...@@ -52,7 +52,7 @@ class ConversationMessageTask:
             message=self.message,
             conversation=self.conversation,
             chain_pub=False,  # disabled currently
-            agent_thought_pub=False  # disabled currently
+            agent_thought_pub=True
         )
     def init(self):
...@@ -69,6 +69,7 @@ class ConversationMessageTask:
             "suggested_questions": self.app_model_config.suggested_questions_list,
             "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
             "more_like_this": self.app_model_config.more_like_this_dict,
+            "sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
             "user_input_form": self.app_model_config.user_input_form_list,
         }
...@@ -207,7 +208,28 @@ class ConversationMessageTask:
         self._pub_handler.pub_chain(message_chain)
-    def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
+    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
+        message_agent_thought = MessageAgentThought(
+            message_id=self.message.id,
+            message_chain_id=message_chain.id,
+            position=agent_loop.position,
+            thought=agent_loop.thought,
+            tool=agent_loop.tool_name,
+            tool_input=agent_loop.tool_input,
+            message=agent_loop.prompt,
+            answer=agent_loop.completion,
+            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
+            created_by=self.user.id
+        )
+        db.session.add(message_agent_thought)
+        db.session.flush()
+        self._pub_handler.pub_agent_thought(message_agent_thought)
+        return message_agent_thought
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str,
                      agent_loop: AgentLoop):
         agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
         agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
...@@ -222,34 +244,18 @@ class ConversationMessageTask:
             agent_answer_unit_price
         )
-        message_agent_loop = MessageAgentThought(
-            message_id=self.message.id,
-            message_chain_id=message_chain.id,
-            position=agent_loop.position,
-            thought=agent_loop.thought,
-            tool=agent_loop.tool_name,
-            tool_input=agent_loop.tool_input,
-            observation=agent_loop.tool_output,
-            tool_process_data='',  # currently not support
-            message=agent_loop.prompt,
-            message_token=loop_message_tokens,
-            message_unit_price=agent_message_unit_price,
-            answer=agent_loop.completion,
-            answer_token=loop_answer_tokens,
-            answer_unit_price=agent_answer_unit_price,
-            latency=agent_loop.latency,
-            tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
-            total_price=loop_total_price,
-            currency=llm_constant.model_currency,
-            created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
-            created_by=self.user.id
-        )
-        db.session.add(message_agent_loop)
+        message_agent_thought.observation = agent_loop.tool_output
+        message_agent_thought.tool_process_data = ''  # currently not support
+        message_agent_thought.message_token = loop_message_tokens
+        message_agent_thought.message_unit_price = agent_message_unit_price
+        message_agent_thought.answer_token = loop_answer_tokens
+        message_agent_thought.answer_unit_price = agent_answer_unit_price
+        message_agent_thought.latency = agent_loop.latency
+        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
+        message_agent_thought.total_price = loop_total_price
+        message_agent_thought.currency = llm_constant.model_currency
         db.session.flush()
-        self._pub_handler.pub_agent_thought(message_agent_loop)
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
         dataset_query = DatasetQuery(
             dataset_id=dataset_query_obj.dataset_id,
...@@ -346,16 +352,14 @@ class PubHandler:
         content = {
             'event': 'agent_thought',
             'data': {
+                'id': message_agent_thought.id,
                 'task_id': self._task_id,
                 'message_id': self._message.id,
                 'chain_id': message_agent_thought.message_chain_id,
-                'agent_thought_id': message_agent_thought.id,
                 'position': message_agent_thought.position,
                 'thought': message_agent_thought.thought,
                 'tool': message_agent_thought.tool,
                 'tool_input': message_agent_thought.tool_input,
-                'observation': message_agent_thought.observation,
-                'answer': message_agent_thought.answer,
                 'mode': self._conversation.mode,
                 'conversation_id': self._conversation.id
             }
...@@ -388,6 +392,15 @@ class PubHandler:
     def _is_stopped(self):
         return redis_client.get(self._stopped_cache_key) is not None
+    @classmethod
+    def ping(cls, user: Union[Account | EndUser], task_id: str):
+        content = {
+            'event': 'ping'
+        }
+        channel = cls.generate_channel_name(user, task_id)
+        redis_client.publish(channel, json.dumps(content))
     @classmethod
     def stop(cls, user: Union[Account | EndUser], task_id: str):
         stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
...
 import tempfile
 from pathlib import Path
-from typing import List, Union
+from typing import List, Union, Optional
+import requests
 from langchain.document_loaders import TextLoader, Docx2txtLoader
 from langchain.schema import Document
...@@ -13,6 +14,9 @@ from core.data_loader.loader.pdf import PdfLoader
 from extensions.ext_storage import storage
 from models.model import UploadFile
+SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
+USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
 class FileExtractor:
     @classmethod
...@@ -22,6 +26,25 @@ class FileExtractor:
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
             storage.download(upload_file.key, file_path)
+            return cls.load_from_file(file_path, return_text, upload_file)
+    @classmethod
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
+        response = requests.get(url, headers={
+            "User-Agent": USER_AGENT
+        })
+        with tempfile.TemporaryDirectory() as temp_dir:
+            suffix = Path(url).suffix
+            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
+            with open(file_path, 'wb') as file:
+                file.write(response.content)
+            return cls.load_from_file(file_path, return_text)
+    @classmethod
+    def load_from_file(cls, file_path: str, return_text: bool = False,
+                       upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
         input_file = Path(file_path)
         delimiter = '\n'
         if input_file.suffix == '.xlsx':
...
import time
from typing import List, Optional, Any, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel
class FakeLLM(SimpleChatModel):
"""Fake ChatModel for testing purposes."""
streaming: bool = False
"""Whether to stream the results or not."""
response: str
origin_llm: Optional[BaseLanguageModel] = None
@property
def _llm_type(self) -> str:
return "fake-chat-model"
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
return self.response
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"response": self.response}
def get_num_tokens(self, text: str) -> int:
return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
if self.streaming:
for token in output_str:
if run_manager:
run_manager.on_llm_new_token(token)
time.sleep(0.01)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
llm_output = {"token_usage": {
'prompt_tokens': 0,
'completion_tokens': 0,
'total_tokens': 0,
}}
return ChatResult(generations=[generation], llm_output=llm_output)
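A sketch of how this fake would be exercised in a test (illustrative only, not part of the diff):
fake = FakeLLM(response="hello from the fake", streaming=False)
assert fake.predict("any prompt") == "hello from the fake"
# get_num_tokens() returns 0 unless origin_llm is set to a real model, so
# token-budget code can run against the fake without any API calls.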
...@@ -10,6 +10,9 @@ from core.llm.provider.errors import ValidateFailedError
 from models.provider import ProviderName
+AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
 class AzureProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
         return []
...@@ -50,9 +53,10 @@ class AzureProvider(BaseProvider):
         """
         config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
+        config['openai_api_version'] = AZURE_OPENAI_API_VERSION
         if model_id == 'text-embedding-ada-002':
             config['deployment'] = model_id.replace('.', '') if model_id else None
-            config['chunk_size'] = 1
+            config['chunk_size'] = 16
         else:
             config['deployment_name'] = model_id.replace('.', '') if model_id else None
         return config
...@@ -69,7 +73,7 @@ class AzureProvider(BaseProvider):
         except:
             config = {
                 'openai_api_type': 'azure',
-                'openai_api_version': '2023-03-15-preview',
+                'openai_api_version': AZURE_OPENAI_API_VERSION,
                 'openai_api_base': '',
                 'openai_api_key': ''
             }
...@@ -78,7 +82,7 @@ class AzureProvider(BaseProvider):
         if not config.get('openai_api_key'):
             config = {
                 'openai_api_type': 'azure',
-                'openai_api_version': '2023-03-15-preview',
+                'openai_api_version': AZURE_OPENAI_API_VERSION,
                 'openai_api_base': '',
                 'openai_api_key': ''
             }
...@@ -100,7 +104,7 @@ class AzureProvider(BaseProvider):
             raise ValueError('Config must be a object.')
         if 'openai_api_version' not in config:
-            config['openai_api_version'] = '2023-03-15-preview'
+            config['openai_api_version'] = AZURE_OPENAI_API_VERSION
         self.check_embedding_model(credentials=config)
     except ValidateFailedError as e:
...@@ -119,7 +123,7 @@ class AzureProvider(BaseProvider):
         """
         return json.dumps({
             'openai_api_type': 'azure',
-            'openai_api_version': '2023-03-15-preview',
+            'openai_api_version': AZURE_OPENAI_API_VERSION,
             'openai_api_base': config['openai_api_base'],
             'openai_api_key': self.encrypt_token(config['openai_api_key'])
         })
...
-from langchain.callbacks.manager import Callbacks
-from langchain.schema import BaseMessage, LLMResult
+from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
+from langchain.chat_models.openai import _convert_dict_to_message
+from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
 from langchain.chat_models import AzureChatOpenAI
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Tuple, Union
 from pydantic import root_validator
...@@ -9,6 +10,11 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 class StreamableAzureChatOpenAI(AzureChatOpenAI):
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
...@@ -71,3 +77,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
         params['model_kwargs'] = model_kwargs
         return params
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
+        if self.streaming:
+            inner_completion = ""
+            role = "assistant"
+            params["stream"] = True
+            function_call: Optional[dict] = None
+            for stream_resp in self.completion_with_retry(
+                messages=message_dicts, **params
+            ):
+                if len(stream_resp["choices"]) > 0:
+                    role = stream_resp["choices"][0]["delta"].get("role", role)
+                    token = stream_resp["choices"][0]["delta"].get("content") or ""
+                    inner_completion += token
+                    _function_call = stream_resp["choices"][0]["delta"].get("function_call")
+                    if _function_call:
+                        if function_call is None:
+                            function_call = _function_call
+                        else:
+                            function_call["arguments"] += _function_call["arguments"]
+                    if run_manager:
+                        run_manager.on_llm_new_token(token)
+            message = _convert_dict_to_message(
+                {
+                    "content": inner_completion,
+                    "role": role,
+                    "function_call": function_call,
+                }
+            )
+            return ChatResult(generations=[ChatGeneration(message=message)])
+        response = self.completion_with_retry(messages=message_dicts, **params)
+        return self._create_chat_result(response)
 from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
-from typing import Optional, List, Dict, Mapping, Any
+from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
 from pydantic import root_validator
...@@ -11,6 +11,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 class StreamableAzureOpenAI(AzureOpenAI):
     openai_api_type: str = "azure"
     openai_api_version: str = ""
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
...
 from typing import List, Optional, Any, Dict
+from httpx import Timeout
 from langchain.callbacks.manager import Callbacks
 from langchain.chat_models import ChatAnthropic
-from langchain.schema import BaseMessage, LLMResult
+from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage
+from pydantic import root_validator
 from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
...@@ -12,6 +14,14 @@ class StreamableChatAnthropic(ChatAnthropic):
     Wrapper around Anthropic's large language model.
     """
+    default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0)
+    @root_validator()
+    def prepare_params(cls, values: Dict) -> Dict:
+        values['model_name'] = values.get('model')
+        values['max_tokens'] = values.get('max_tokens_to_sample')
+        return values
     @handle_anthropic_exceptions
     def generate(
         self,
...@@ -37,3 +47,16 @@ class StreamableChatAnthropic(ChatAnthropic):
         del params['presence_penalty']
         return params
+    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"{self.HUMAN_PROMPT} {message.content}"
+        elif isinstance(message, AIMessage):
+            message_text = f"{self.AI_PROMPT} {message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"<admin>{message.content}</admin>"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
\ No newline at end of file
...@@ -3,7 +3,7 @@ import os
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, Union, Tuple
 from pydantic import root_validator
...@@ -11,6 +11,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 class StreamableChatOpenAI(ChatOpenAI):
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    max_retries: int = 1
+    """Maximum number of retries to make when generating."""
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
...
@@ -2,7 +2,7 @@ import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping, Union, Tuple
from langchain import OpenAI
from pydantic import root_validator

@@ -10,6 +10,10 @@ from core.llm.wrappers.openai_wrapper import handle_openai_exceptions

class StreamableOpenAI(OpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API: 5 seconds to connect, 300 seconds to read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        ...
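All three OpenAI-flavored wrappers pin the same client behavior: a (connect, read) timeout tuple of (5.0, 300.0) seconds and a single retry instead of langchain's larger default. A minimal standalone sketch of the same override, assuming the requests-style tuple semantics that openai 0.27 accepts:

from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(
    request_timeout=(5.0, 300.0),  # 5 s to establish the connection, 300 s to read the response
    max_retries=1,                 # fail fast rather than retrying repeatedly
)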
This diff is collapsed.
import re
from typing import Type

from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset
class DatasetRetrieverToolInput(BaseModel):
    dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    k: int = 3

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description.replace('\n', '').replace('\r', '')
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name
        description += '\nID of dataset MUST be ' + dataset.id
        return cls(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, dataset_id: str, query: str) -> str:
        pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
        match = re.search(pattern, dataset_id, re.IGNORECASE)
        if match:
            dataset_id = match.group()

        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            return f'[{self.name} failed to find dataset with id {dataset_id}.]'

        if dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

@@ -39,49 +80,26 @@ class DatasetRetrieverTool(BaseTool):
            ))

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                query,
                search_type='similarity',
                search_kwargs={
                    'k': self.k
                }
            )

        hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
        hit_callback.on_tool_end(documents)

        return str("\n".join([document.page_content for document in documents]))

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
import base64
from abc import ABC, abstractmethod
from typing import Optional

from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName


class BaseToolProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    @abstractmethod
    def get_provider_name(self) -> ToolProviderName:
        raise NotImplementedError

    @abstractmethod
    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_to_func_kwargs(self) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_validate(self, credentials: dict):
        raise NotImplementedError

    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
        """
        Returns the Provider instance for the given tenant_id and tool_name.
        """
        query = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == self.tenant_id,
            ToolProvider.tool_name == self.get_provider_name().value
        )

        if must_enabled:
            query = query.filter(ToolProvider.is_enabled == True)

        return query.first()

    def encrypt_token(self, token) -> str:
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

        if obfuscated:
            return self._obfuscated_token(token)

        return token

    def _obfuscated_token(self, token: str) -> str:
        return token[:6] + '*' * (len(token) - 8) + token[-2:]
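For illustration (not part of the diff), the obfuscation rule keeps the first six and last two characters and masks everything in between; the key below is made up:

token = "sk-abcdef123456"
obfuscated = token[:6] + '*' * (len(token) - 8) + token[-2:]
assert obfuscated == "sk-abc*******56"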
class ToolValidateFailedError(Exception):
    description = "Tool Provider Validate failed"
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName


class SerpAPIToolProvider(BaseToolProvider):
    def get_provider_name(self) -> ToolProviderName:
        """
        Returns the name of the provider.

        :return:
        """
        return ToolProviderName.SERPAPI

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for SerpAPI as a dictionary.

        :param obfuscated: obfuscate credentials if True
        :return:
        """
        tool_provider = self.get_provider(must_enabled=True)
        if not tool_provider:
            return None

        credentials = tool_provider.credentials
        if not credentials:
            return None

        if credentials.get('api_key'):
            credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)

        return credentials

    def credentials_to_func_kwargs(self) -> Optional[dict]:
        """
        Returns the credentials function kwargs as a dictionary.

        :return:
        """
        credentials = self.get_credentials()
        if not credentials:
            return None

        return {
            'serpapi_api_key': credentials.get('api_key')
        }

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :return:
        """
        if 'api_key' not in credentials or not credentials.get('api_key'):
            raise ToolValidateFailedError("SerpAPI api_key is required.")

        api_key = credentials.get('api_key')

        try:
            OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
        except Exception as e:
            raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))

    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
        return credentials
from typing import Optional

from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider


class ToolProviderService:

    def __init__(self, tenant_id: str, provider_name: str):
        self.provider = self._init_provider(tenant_id, provider_name)

    def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
        if provider_name == 'serpapi':
            return SerpAPIToolProvider(tenant_id)
        else:
            raise Exception('tool provider {} not found'.format(provider_name))

    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Returns the credentials for Tool as a dictionary.

        :param obfuscated:
        :return:
        """
        return self.provider.get_credentials(obfuscated)

    def credentials_validate(self, credentials: dict):
        """
        Validates the given credentials.

        :param credentials:
        :raises: ToolValidateFailedError
        """
        return self.provider.credentials_validate(credentials)

    def encrypt_credentials(self, credentials: dict):
        """
        Encrypts the given credentials.

        :param credentials:
        :return:
        """
        return self.provider.encrypt_credentials(credentials)
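A hedged sketch (not part of the diff) of how a controller might drive the service when a workspace saves SerpAPI credentials; the tenant id and key are placeholders:

service = ToolProviderService(tenant_id='<tenant-uuid>', provider_name='serpapi')

credentials = {'api_key': '<serpapi-api-key>'}
service.credentials_validate(credentials)             # raises ToolValidateFailedError on bad keys
encrypted = service.encrypt_credentials(credentials)  # api_key is now RSA-encrypted and base64-encoded

# `encrypted` would then be JSON-serialized into ToolProvider.encrypted_credentials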
from langchain import SerpAPIWrapper
from pydantic import Field, BaseModel


class OptimizedSerpAPIInput(BaseModel):
    query: str = Field(..., description="search query.")


class OptimizedSerpAPIWrapper(SerpAPIWrapper):

    @staticmethod
    def _process_response(res: dict, num_results: int = 5) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box" in res.keys() and type(res["answer_box"]) == list:
            res["answer_box"] = res["answer_box"][0]
        if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
            toret = res["answer_box"]["answer"]
        elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
            toret = res["answer_box"]["snippet"]
        elif (
            "answer_box" in res.keys()
            and "snippet_highlighted_words" in res["answer_box"].keys()
        ):
            toret = res["answer_box"]["snippet_highlighted_words"][0]
        elif (
            "sports_results" in res.keys()
            and "game_spotlight" in res["sports_results"].keys()
        ):
            toret = res["sports_results"]["game_spotlight"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            toret = res["shopping_results"][:3]
        elif (
            "knowledge_graph" in res.keys()
            and "description" in res["knowledge_graph"].keys()
        ):
            toret = res["knowledge_graph"]["description"]
        elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
            toret = ""
            for result in res["organic_results"][:num_results]:
                if "link" in result:
                    toret += "----------------\nlink: " + result["link"] + "\n"
                if "snippet" in result:
                    toret += "snippet: " + result["snippet"] + "\n"
        else:
            toret = "No good search result found"
        return "search result:\n" + toret
This diff is collapsed.
"""add is_universal in apps
Revision ID: 2beac44e5f5f
Revises: d3d503a3471c
Create Date: 2023-07-07 12:11:29.156057
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2beac44e5f5f'
down_revision = 'a5b56fb053ef'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_universal', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.drop_column('is_universal')
# ### end Alembic commands ###
"""add tool providers
Revision ID: 7ce5a52e4eee
Revises: 2beac44e5f5f
Create Date: 2023-07-10 10:26:50.074515
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '7ce5a52e4eee'
down_revision = '2beac44e5f5f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_providers',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
sa.Column('tool_name', sa.String(length=40), nullable=False),
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
sa.Column('is_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('sensitive_word_avoidance')
op.drop_table('tool_providers')
# ### end Alembic commands ###
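The two revisions chain off the existing head (a5b56fb053ef -> 2beac44e5f5f -> 7ce5a52e4eee), so deployments pick them both up in order with the API service's usual Flask-Migrate step, `flask db upgrade`.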
@@ -40,6 +40,7 @@ class App(db.Model):
    api_rph = db.Column(db.Integer, nullable=False)
    is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

@@ -88,6 +89,7 @@ class AppModelConfig(db.Model):
    user_input_form = db.Column(db.Text)
    pre_prompt = db.Column(db.Text)
    agent_mode = db.Column(db.Text)
    sensitive_word_avoidance = db.Column(db.Text)

    @property
    def app(self):

@@ -116,14 +118,35 @@ class AppModelConfig(db.Model):
    def more_like_this_dict(self) -> dict:
        return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}

    @property
    def sensitive_word_avoidance_dict(self) -> dict:
        return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \
            else {"enabled": False, "words": [], "canned_response": []}

    @property
    def user_input_form_list(self) -> dict:
        return json.loads(self.user_input_form) if self.user_input_form else []

    @property
    def agent_mode_dict(self) -> dict:
        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []}

    def to_dict(self) -> dict:
        return {
            "provider": "",
            "model_id": "",
            "configs": {},
            "opening_statement": self.opening_statement,
            "suggested_questions": self.suggested_questions_list,
            "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
            "speech_to_text": self.speech_to_text_dict,
            "more_like_this": self.more_like_this_dict,
            "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
            "model": self.model_dict,
            "user_input_form": self.user_input_form_list,
            "pre_prompt": self.pre_prompt,
            "agent_mode": self.agent_mode_dict
        }


class RecommendedApp(db.Model):
    __tablename__ = 'recommended_apps'

@@ -237,6 +260,9 @@ class Conversation(db.Model):
                if 'speech_to_text' in override_model_configs else {"enabled": False}
            model_config['more_like_this'] = override_model_configs['more_like_this'] \
                if 'more_like_this' in override_model_configs else {"enabled": False}
            model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \
                if 'sensitive_word_avoidance' in override_model_configs \
                else {"enabled": False, "words": [], "canned_response": []}
            model_config['user_input_form'] = override_model_configs['user_input_form']
        else:
            model_config['configs'] = override_model_configs

@@ -253,6 +279,7 @@ class Conversation(db.Model):
        model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
        model_config['speech_to_text'] = app_model_config.speech_to_text_dict
        model_config['more_like_this'] = app_model_config.more_like_this_dict
        model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
        model_config['user_input_form'] = app_model_config.user_input_form_list

        model_config['model_id'] = self.model_id

@@ -393,6 +420,11 @@ class Message(db.Model):
    def in_debug_mode(self):
        return self.override_model_configs is not None

    @property
    def agent_thoughts(self):
        return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id)\
            .order_by(MessageAgentThought.position.asc()).all()


class MessageFeedback(db.Model):
    __tablename__ = 'message_feedbacks'
...
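The new AppModelConfig.sensitive_word_avoidance column stores JSON text. A hedged illustration (not part of the diff) of a populated value, assuming the string form that the validation service below enforces for `words` and `canned_response` (the comma-separated word list is an assumption):

import json

stored = json.dumps({
    "enabled": True,
    "words": "badword1,badword2",
    "canned_response": "Sorry, I can't discuss that topic.",
})
# AppModelConfig.sensitive_word_avoidance_dict would then return the parsed dict.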
import json
from enum import Enum

from sqlalchemy.dialects.postgresql import UUID

from extensions.ext_database import db


class ToolProviderName(Enum):
    SERPAPI = 'serpapi'

    @staticmethod
    def value_of(value):
        for member in ToolProviderName:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class ToolProvider(db.Model):
    __tablename__ = 'tool_providers'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(UUID, nullable=False)
    tool_name = db.Column(db.String(40), nullable=False)
    encrypted_credentials = db.Column(db.Text, nullable=True)
    is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

    @property
    def credentials_is_set(self):
        """
        Returns True if encrypted_credentials is not None, indicating that the token is set.
        """
        return self.encrypted_credentials is not None

    @property
    def credentials(self):
        """
        Returns the decrypted config.
        """
        return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None
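A quick illustration (not part of the diff) of the enum lookup helper:

assert ToolProviderName.value_of('serpapi') is ToolProviderName.SERPAPI
# ToolProviderName.value_of('unknown') raises ValueError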
@@ -10,8 +10,8 @@ flask-session2==1.3.1
flask-cors==3.0.10
gunicorn~=20.1.0
gevent~=22.10.2
langchain==0.0.239
openai~=0.27.8
psycopg2-binary~=2.9.6
pycryptodome==3.17
python-dotenv==1.0.0

@@ -36,3 +36,8 @@ pypdfium2==4.16.0
resend~=0.5.1
pyjwt~=2.6.0
anthropic~=0.3.4
newspaper3k==0.2.8
google-api-python-client==2.90.0
wikipedia==1.4.0
readabilipy==0.2.0
google-search-results==2.4.2
\ No newline at end of file
import re
import uuid

from core.agent.agent_executor import PlanningStrategy
from core.constant import llm_constant
from models.account import Account
from services.dataset_service import DatasetService

@@ -31,6 +32,16 @@ MODELS_BY_APP_MODE = {
    ]
}

SUPPORT_AGENT_MODELS = [
    "gpt-4",
    "gpt-4-32k",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
]

SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia"]


class AppModelConfigService:
    @staticmethod
    def is_dataset_exists(account: Account, dataset_id: str) -> bool:

@@ -58,7 +69,8 @@ class AppModelConfigService:
            if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \
                    llm_constant.max_context_token_length[model_name]:
                raise ValueError(
                    "max_tokens must be an integer greater than 0 "
                    "and not exceeding the maximum value of the corresponding model")

            # temperature
            if 'temperature' not in cp:

@@ -149,11 +161,6 @@ class AppModelConfigService:
        if not isinstance(config["speech_to_text"]["enabled"], bool):
            raise ValueError("enabled in speech_to_text must be of boolean type")

        # more_like_this
        if 'more_like_this' not in config or not config["more_like_this"]:
            config["more_like_this"] = {

@@ -169,6 +176,33 @@ class AppModelConfigService:
        if not isinstance(config["more_like_this"]["enabled"], bool):
            raise ValueError("enabled in more_like_this must be of boolean type")

        # sensitive_word_avoidance
        if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
            config["sensitive_word_avoidance"] = {
                "enabled": False
            }

        if not isinstance(config["sensitive_word_avoidance"], dict):
            raise ValueError("sensitive_word_avoidance must be of dict type")

        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
            config["sensitive_word_avoidance"]["enabled"] = False

        if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
            raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")

        if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
            config["sensitive_word_avoidance"]["words"] = ""

        if not isinstance(config["sensitive_word_avoidance"]["words"], str):
            raise ValueError("words in sensitive_word_avoidance must be of string type")

        if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
            config["sensitive_word_avoidance"]["canned_response"] = ""

        if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
            raise ValueError("canned_response in sensitive_word_avoidance must be of string type")

        # model
        if 'model' not in config:
            raise ValueError("model is required")

@@ -274,6 +308,12 @@ class AppModelConfigService:
        if not isinstance(config["agent_mode"]["enabled"], bool):
            raise ValueError("enabled in agent_mode must be of boolean type")

        if "strategy" not in config["agent_mode"] or not config["agent_mode"]["strategy"]:
            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value

        if config["agent_mode"]["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
            raise ValueError("strategy in agent_mode must be in the specified strategy list")

        if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]:
            config["agent_mode"]["tools"] = []

@@ -282,8 +322,8 @@ class AppModelConfigService:
        for tool in config["agent_mode"]["tools"]:
            key = list(tool.keys())[0]
            if key not in SUPPORT_TOOLS:
                raise ValueError("Keys in agent_mode.tools must be in the specified tool list")

            tool_item = tool[key]

@@ -293,19 +333,7 @@ class AppModelConfigService:
            if not isinstance(tool_item["enabled"], bool):
                raise ValueError("enabled in agent_mode.tools must be of boolean type")

            if key == "dataset":
                if 'id' not in tool_item:
                    raise ValueError("id is required in dataset")

@@ -324,6 +352,7 @@ class AppModelConfigService:
            "suggested_questions_after_answer": config["suggested_questions_after_answer"],
            "speech_to_text": config["speech_to_text"],
            "more_like_this": config["more_like_this"],
            "sensitive_word_avoidance": config["sensitive_word_avoidance"],
            "model": {
                "provider": config["model"]["provider"],
                "name": config["model"]["name"],
...
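A hedged example (not part of the diff) of an agent_mode payload that passes the checks above, assuming PlanningStrategy.ROUTER.value == 'router'; the dataset UUID is a placeholder:

agent_mode = {
    "enabled": True,
    "strategy": "router",  # must be one of the PlanningStrategy values
    "tools": [
        {"dataset": {"enabled": True, "id": "<dataset-uuid>"}},
        {"google_search": {"enabled": True}},
    ],
}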
@@ -37,6 +37,8 @@ class CompletionService:
        if not query:
            raise ValueError('query is required')

        query = query.replace('\x00', '')

        conversation_id = args['conversation_id'] if 'conversation_id' in args else None

        conversation = None

@@ -140,6 +142,7 @@ class CompletionService:
                suggested_questions=json.dumps(model_config['suggested_questions']),
                suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
                more_like_this=json.dumps(model_config['more_like_this']),
                sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
                model=json.dumps(model_config['model']),
                user_input_form=json.dumps(model_config['user_input_form']),
                pre_prompt=model_config['pre_prompt'],

@@ -171,7 +174,7 @@ class CompletionService:
        generate_worker_thread.start()

        # wait for 10 minutes to close the thread
        cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)

        return cls.compact_response(pubsub, streaming)

@@ -179,9 +182,9 @@ class CompletionService:
    @classmethod
    def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
        if isinstance(user, Account):
            user = db.session.query(Account).filter(Account.id == user.id).first()
        elif isinstance(user, EndUser):
            user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
        else:
            raise Exception("Unknown user type")

@@ -226,12 +229,15 @@ class CompletionService:
    @classmethod
    def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
        # wait for 10 minutes to close the thread
        timeout = 600

        def close_pubsub():
            sleep_iterations = 0
            while sleep_iterations < timeout and worker_thread.is_alive():
                if sleep_iterations > 0 and sleep_iterations % 10 == 0:
                    PubHandler.ping(user, generate_task_id)

                time.sleep(1)
                sleep_iterations += 1

@@ -369,7 +375,7 @@ class CompletionService:
            if len(value) > max_length:
                raise ValueError(f'{variable} in input form must be less than {max_length} characters')

            filtered_inputs[variable] = value.replace('\x00', '') if value else None

        return filtered_inputs

@@ -418,6 +424,10 @@ class CompletionService:
                        yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                    elif event == 'agent_thought':
                        yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
                    elif event == 'ping':
                        yield "event: ping\n\n"
                    else:
                        yield "data: " + json.dumps(result) + "\n\n"
        except ValueError as e:
            if e.args[0] != "I/O operation on closed file.":  # ignore this error
                logging.exception(e)

@@ -467,16 +477,14 @@ class CompletionService:
    def get_agent_thought_response_data(cls, data: dict):
        response_data = {
            'event': 'agent_thought',
            'id': data.get('id'),
            'chain_id': data.get('chain_id'),
            'task_id': data.get('task_id'),
            'message_id': data.get('message_id'),
            'position': data.get('position'),
            'thought': data.get('thought'),
            'tool': data.get('tool'),
            'tool_input': data.get('tool_input'),
            'created_at': int(time.time())
        }
...
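The streaming endpoint now emits a bare `event: ping` frame roughly every 10 seconds while the worker thread is alive, which keeps proxies from closing long-lived SSE connections. A hedged client-side sketch (not part of the diff) of consuming the stream; the endpoint path and payload are placeholders for whichever chat-messages API the app exposes:

import requests

with requests.post('<api-base>/chat-messages', json={'query': 'Hello'}, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line == 'event: ping':
            continue  # keep-alive frame, no payload to parse
        if line.startswith('data: '):
            payload = line[len('data: '):]  # JSON for message / agent_thought / chain events
            print(payload)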