Unverified Commit b554c607 authored by Yeuoly's avatar Yeuoly

Merge branch 'main' into feat/enterprise

parents d942668b dd961985
...@@ -90,7 +90,7 @@ class Config: ...@@ -90,7 +90,7 @@ class Config:
# ------------------------ # ------------------------
# General Configurations. # General Configurations.
# ------------------------ # ------------------------
self.CURRENT_VERSION = "0.5.6" self.CURRENT_VERSION = "0.5.7"
self.COMMIT_SHA = get_env('COMMIT_SHA') self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED" self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV') self.DEPLOY_ENV = get_env('DEPLOY_ENV')
......
from extensions.ext_database import db
from models.model import EndUser
def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
import json import json
from flask import current_app from flask import current_app
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppModelConfig from models.model import App, AppModelConfig
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
class AppParameterApi(AppApiResource): class AppParameterApi(Resource):
"""Resource for app variables.""" """Resource for app variables."""
variable_fields = { variable_fields = {
...@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource): ...@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields) 'system_parameters': fields.Nested(system_parameters_fields)
} }
@validate_app_token
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
def get(self, app_model: App, end_user): def get(self, app_model: App):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
...@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource): ...@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
} }
} }
class AppMetaApi(AppApiResource): class AppMetaApi(Resource):
def get(self, app_model: App, end_user): @validate_app_token
def get(self, app_model: App):
"""Get app meta""" """Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config app_model_config: AppModelConfig = app_model.app_model_config
......
import logging import logging
from flask import request from flask import request
from flask_restful import reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
...@@ -17,10 +17,10 @@ from controllers.service_api.app.error import ( ...@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
UnsupportedAudioTypeError, UnsupportedAudioTypeError,
) )
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
...@@ -30,8 +30,9 @@ from services.errors.audio import ( ...@@ -30,8 +30,9 @@ from services.errors.audio import (
) )
class AudioApi(AppApiResource): class AudioApi(Resource):
def post(self, app_model: App, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']: if not app_model_config.speech_to_text_dict['enabled']:
...@@ -73,11 +74,11 @@ class AudioApi(AppApiResource): ...@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError() raise InternalServerError()
class TextApi(AppApiResource): class TextApi(Resource):
def post(self, app_model: App, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json') parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
...@@ -85,7 +86,7 @@ class TextApi(AppApiResource): ...@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
text=args['text'], text=args['text'],
end_user=args['user'], end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'), voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming'] streaming=args['streaming']
) )
......
...@@ -4,12 +4,11 @@ from collections.abc import Generator ...@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union from typing import Union
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
...@@ -19,17 +18,19 @@ from controllers.service_api.app.error import ( ...@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError, ProviderNotInitializeError,
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService from services.completion_service import CompletionService
class CompletionApi(AppApiResource): class CompletionApi(Resource):
def post(self, app_model, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
...@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource): ...@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='') parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False args['auto_generate_name'] = False
try: try:
...@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource): ...@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError() raise InternalServerError()
class CompletionStopApi(AppApiResource): class CompletionStopApi(Resource):
def post(self, app_model, end_user, task_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion': if app_model.mode != 'completion':
raise AppUnavailableError() raise AppUnavailableError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
class ChatApi(AppApiResource): class ChatApi(Resource):
def post(self, app_model, end_user): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -114,7 +102,6 @@ class ChatApi(AppApiResource): ...@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json') parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
...@@ -122,9 +109,6 @@ class ChatApi(AppApiResource): ...@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming' streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
response = CompletionService.completion( response = CompletionService.completion(
app_model=app_model, app_model=app_model,
...@@ -157,22 +141,12 @@ class ChatApi(AppApiResource): ...@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError() raise InternalServerError()
class ChatStopApi(AppApiResource): class ChatStopApi(Resource):
def post(self, app_model, end_user, task_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200 return {'result': 'success'}, 200
......
from flask import request from flask_restful import Resource, marshal_with, reqparse
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
class ConversationApi(AppApiResource): class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError: except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource): class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id): def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
conversation_id = str(c_id) conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try: try:
ConversationService.delete(app_model, conversation_id, end_user) ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
...@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource): ...@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204 return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource): class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id): def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource): ...@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json') parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
return ConversationService.rename( return ConversationService.rename(
app_model, app_model,
......
from flask import request from flask import request
from flask_restful import marshal_with from flask_restful import Resource, marshal_with
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
FileTooLargeError, FileTooLargeError,
NoFileUploadedError, NoFileUploadedError,
TooManyFilesError, TooManyFilesError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService from services.file_service import FileService
class FileApi(AppApiResource): class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields) @marshal_with(file_fields)
def post(self, app_model, end_user): def post(self, app_model: App, end_user: EndUser):
file = request.files['file'] file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file # check file
if 'file' not in request.files: if 'file' not in request.files:
......
from flask_restful import fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message from models.model import App, EndUser
from services.message_service import MessageService from services.message_service import MessageService
class MessageListApi(AppApiResource): class MessageListApi(Resource):
feedback_fields = { feedback_fields = {
'rating': fields.String 'rating': fields.String
} }
...@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource): ...@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields)) 'data': fields.List(fields.Nested(message_fields))
} }
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user): def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
...@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource): ...@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
return MessageService.pagination_by_first_id(app_model, end_user, return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit']) args['conversation_id'], args['first_id'], args['limit'])
...@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource): ...@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.") raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(AppApiResource): class MessageFeedbackApi(Resource):
def post(self, app_model, end_user, message_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try: try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating']) MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
...@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource): ...@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'} return {'result': 'success'}
class MessageSuggestedApi(AppApiResource): class MessageSuggestedApi(Resource):
def get(self, app_model, end_user, message_id): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id) message_id = str(message_id)
if app_model.mode != 'chat': if app_model.mode != 'chat':
raise NotChatAppError() raise NotChatAppError()
try: try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, app_model=app_model,
user=user, user=end_user,
message_id=message_id, message_id=message_id,
check_enabled=False check_enabled=False
) )
......
from collections.abc import Callable
from datetime import datetime from datetime import datetime
from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_restful import Resource from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import _get_user from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService
def validate_app_token(view=None): class WhereisUserArg(Enum):
def decorator(view): """
@wraps(view) Enum for whereis_user_arg.
def decorated(*args, **kwargs): """
QUERY = 'query'
JSON = 'json'
FORM = 'form'
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app') api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first() app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
...@@ -29,15 +47,34 @@ def validate_app_token(view=None): ...@@ -29,15 +47,34 @@ def validate_app_token(view=None):
if not app_model.enable_api: if not app_model.enable_api:
raise NotFound() raise NotFound()
return view(app_model, None, *args, **kwargs) kwargs['app_model'] = app_model
return decorated
if view: if fetch_user_arg:
return decorator(view) if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None
# if view is None, it means that the decorator is used without parentheses if not user_id and fetch_user_arg.required:
# use the decorator as a function for method_decorators raise ValueError("Arg user must be provided.")
return decorator
if user_id:
user_id = str(user_id)
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def cloud_edition_billing_resource_check(resource: str, def cloud_edition_billing_resource_check(resource: str,
...@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None): ...@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
return api_token return api_token
class AppApiResource(Resource): def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
method_decorators = [validate_app_token] """
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource): class DatasetApiResource(Resource):
......
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
:param model_config:
:param messages:
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass
This diff is collapsed.
This diff is collapsed.
import json
import logging import logging
from typing import cast from typing import cast
...@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner): ...@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables # convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
message_chain = self._init_message_chain(
message=message,
query=query
)
# init model instance # init model instance
model_instance = ModelInstance( model_instance = ModelInstance(
...@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner): ...@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables 'pool': db_variables.variables
}) })
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage: message: Message) -> LLMUsage:
""" """
......
...@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner ...@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
......
...@@ -175,7 +175,7 @@ class GenerateTaskPipeline: ...@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
'id': self._message.id, 'id': self._message.id,
'message_id': self._message.id, 'message_id': self._message.id,
'mode': self._conversation.mode, 'mode': self._conversation.mode,
'answer': event.llm_result.message.content, 'answer': self._task_state.llm_result.message.content,
'metadata': {}, 'metadata': {},
'created_at': int(self._message.created_at.timestamp()) 'created_at': int(self._message.created_at.timestamp())
} }
......
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool
\ No newline at end of file
...@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun ...@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain): class LLMChain(LCLLMChain):
......
...@@ -12,9 +12,9 @@ from pydantic import root_validator ...@@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent): class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
......
...@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy ...@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
......
import enum
import logging import logging
from typing import Optional, Union from typing import Optional, Union
...@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks ...@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.entities.agent_entities import PlanningStrategy
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
...@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas ...@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
class AgentConfiguration(BaseModel): class AgentConfiguration(BaseModel):
strategy: PlanningStrategy strategy: PlanningStrategy
model_config: ModelConfigEntity model_config: ModelConfigEntity
...@@ -62,28 +53,7 @@ class AgentExecutor: ...@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent() self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT: if self.configuration.strategy == PlanningStrategy.ROUTER:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool) if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)] or isinstance(t, DatasetMultiRetrieverTool)]
......
...@@ -2,9 +2,10 @@ from typing import Optional, cast ...@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool from langchain.tools import BaseTool
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
- bedrock - bedrock
- togetherai - togetherai
- ollama - ollama
- mistralai
- replicate - replicate
- huggingface_hub - huggingface_hub
- zhipuai - zhipuai
......
...@@ -4,7 +4,6 @@ from typing import Any, Optional ...@@ -4,7 +4,6 @@ from typing import Any, Optional
import requests import requests
from flask import current_app from flask import current_app
from flask_login import current_user
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
...@@ -43,7 +42,7 @@ class NotionExtractor(BaseExtractor): ...@@ -43,7 +42,7 @@ class NotionExtractor(BaseExtractor):
if notion_access_token: if notion_access_token:
self._notion_access_token = notion_access_token self._notion_access_token = notion_access_token
else: else:
self._notion_access_token = self._get_access_token(current_user.current_tenant_id, self._notion_access_token = self._get_access_token(tenant_id,
self._notion_workspace_id) self._notion_workspace_id)
if not self._notion_access_token: if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
......
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket
class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'
model_api_configs = {
'spark': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-v2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-v3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-v3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
}
}
api_version = model_api_configs[model_name]['version']
self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
urlparse(self.api_base).path,
self.api_base,
api_key,
api_secret
)
self.queue = queue.Queue()
self.blocking_message = ''
def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
# generate timestamp by RFC1123
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"
# encrypt using hmac-sha256
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
}
# generate url
url = api_base + '?' + urlencode(v)
return url
def run(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None, streaming: bool = False):
websocket.enableTrace(False)
ws = websocket.WebSocketApp(
self.ws_url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
on_open=self.on_open
)
ws.messages = messages
ws.user_id = user_id
ws.model_kwargs = model_kwargs
ws.streaming = streaming
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error):
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close()
def on_close(self, ws, close_status_code, close_reason):
self.queue.put({'done': True})
def on_open(self, ws):
self.blocking_message = ''
data = json.dumps(self.gen_params(
messages=ws.messages,
user_id=ws.user_id,
model_kwargs=ws.model_kwargs
))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
self.queue.put({
'status_code': 400,
'error': f"Code: {code}, Error: {data['header']['message']}"
})
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
if ws.streaming:
self.queue.put({'data': content})
else:
self.blocking_message += content
if status == 2:
if not ws.streaming:
self.queue.put({'data': self.blocking_message})
ws.close()
def gen_params(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None) -> dict:
data = {
"header": {
"app_id": self.app_id,
"uid": user_id
},
"parameter": {
"chat": {
"domain": self.chat_domain
}
},
"payload": {
"message": {
"text": messages
}
}
}
if model_kwargs:
data['parameter']['chat'].update(model_kwargs)
return data
def subscribe(self):
while True:
content = self.queue.get()
if 'error' in content:
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
if 'data' not in content:
break
yield content
class SparkError(Exception):
pass
from datetime import datetime
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class DatetimeToolInput(BaseModel):
type: str = Field(..., description="Type for current time, must be: datetime.")
class DatetimeTool(BaseTool):
"""Tool for querying current datetime."""
name: str = "current_datetime"
args_schema: type[BaseModel] = DatetimeToolInput
description: str = "A tool when you want to get the current date, time, week, month or year, " \
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
def _run(self, type: str) -> str:
# get current time
current_time = datetime.utcnow()
return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.tool_name == self.get_provider_name().value
)
if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)
return query.first()
def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
if obfuscated:
return self._obfuscated_token(token)
return token
def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
class ToolValidateFailedError(Exception):
description = "Tool Provider Validate failed"
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName
class SerpAPIToolProvider(BaseToolProvider):
def get_provider_name(self) -> ToolProviderName:
"""
Returns the name of the provider.
:return:
"""
return ToolProviderName.SERPAPI
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for SerpAPI as a dictionary.
:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider = self.get_provider(must_enabled=True)
if not tool_provider:
return None
credentials = tool_provider.credentials
if not credentials:
return None
if credentials.get('api_key'):
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
return credentials
def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
Returns the credentials function kwargs as a dictionary.
:return:
"""
credentials = self.get_credentials()
if not credentials:
return None
return {
'serpapi_api_key': credentials.get('api_key')
}
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolValidateFailedError("SerpAPI api_key is required.")
api_key = credentials.get('api_key')
try:
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
except Exception as e:
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
return credentials
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
class ToolProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self._init_provider(tenant_id, provider_name)
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
if provider_name == 'serpapi':
return SerpAPIToolProvider(tenant_id)
else:
raise Exception('tool provider {} not found'.format(provider_name))
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for Tool as a dictionary.
:param obfuscated:
:return:
"""
return self.provider.get_credentials(obfuscated)
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)
def encrypt_credentials(self, credentials: dict):
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
return self.provider.encrypt_credentials(credentials)
from langchain import SerpAPIWrapper
from pydantic import BaseModel, Field
class OptimizedSerpAPIInput(BaseModel):
query: str = Field(..., description="search query.")
class OptimizedSerpAPIWrapper(SerpAPIWrapper):
@staticmethod
def _process_response(res: dict, num_results: int = 5) -> str:
"""Process response from SerpAPI."""
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
res["answer_box"] = res["answer_box"][0]
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
toret = res["sports_results"]["game_spotlight"]
elif (
"shopping_results" in res.keys()
and "title" in res["shopping_results"][0].keys()
):
toret = res["shopping_results"][:3]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
toret = res["knowledge_graph"]["description"]
elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
toret = ""
for result in res["organic_results"][:num_results]:
if "link" in result:
toret += "----------------\nlink: " + result["link"] + "\n"
if "snippet" in result:
toret += "snippet: " + result["snippet"] + "\n"
else:
toret = "No good search result found"
return "search result:\n" + toret
This diff is collapsed.
...@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController): ...@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController):
en_US='The api key', en_US='The api key',
zh_Hans='api key的值' zh_Hans='api key的值'
) )
),
'api_key_header_prefix': ToolProviderCredentials(
name='api_key_header_prefix',
required=False,
default='basic',
type=ToolProviderCredentials.CredentialsType.SELECT,
help=I18nObject(
en_US='The prefix of the api key header',
zh_Hans='api key header 的前缀'
),
options=[
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
]
) )
} }
elif auth_type == ApiProviderAuthType.NONE: elif auth_type == ApiProviderAuthType.NONE:
......
...@@ -62,6 +62,17 @@ class ApiTool(Tool): ...@@ -62,6 +62,17 @@ class ApiTool(Tool):
if 'api_key_value' not in credentials: if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value') raise ToolProviderCredentialValidationError('Missing api_key_value')
elif not isinstance(credentials['api_key_value'], str):
raise ToolProviderCredentialValidationError('api_key_value must be a string')
if 'api_key_header_prefix' in credentials:
api_key_header_prefix = credentials['api_key_header_prefix']
if api_key_header_prefix == 'basic':
credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
elif api_key_header_prefix == 'bearer':
credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == 'custom':
pass
headers[api_key_header] = credentials['api_key_value'] headers[api_key_header] = credentials['api_key_value']
......
...@@ -4,7 +4,7 @@ from langchain.tools import BaseTool ...@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
...@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool): ...@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
@staticmethod @staticmethod
def get_dataset_tools(tenant_id: str, def get_dataset_tools(tenant_id: str,
dataset_ids: list[str], dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity, retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool, return_resource: bool,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']: ) -> list['DatasetRetrieverTool']:
""" """
get dataset tool get dataset tool
""" """
...@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool): ...@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
) )
# restore retrieve strategy # restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools # convert langchain tools to Tools
tools = [] tools = []
for langchain_tool in langchain_tools: for langchain_tool in langchain_tools:
...@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool): ...@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
llm=langchain_tool.description), llm=langchain_tool.description),
runtime=DatasetRetrieverTool.Runtime() runtime=DatasetRetrieverTool.Runtime()
) )
tools.append(tool) tools.append(tool)
return tools return tools
...@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool): ...@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(self) -> list[ToolParameter]:
return [ return [
ToolParameter(name='query', ToolParameter(name='query',
label=I18nObject(en_US='', zh_Hans=''), label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''), human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.', llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True, required=True,
default=''), default=''),
] ]
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
...@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool): ...@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
query = tool_parameters.get('query', None) query = tool_parameters.get('query', None)
if not query: if not query:
return self.create_text_message(text='please input query') return self.create_text_message(text='please input query')
# invoke dataset retriever tool # invoke dataset retriever tool
result = self.langchain_tool._run(query=query) result = self.langchain_tool._run(query=query)
...@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool): ...@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
""" """
validate the credentials for dataset retriever tool validate the credentials for dataset retriever tool
""" """
pass pass
\ No newline at end of file
...@@ -7,23 +7,14 @@ import subprocess ...@@ -7,23 +7,14 @@ import subprocess
import tempfile import tempfile
import unicodedata import unicodedata
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any
import requests import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex from regex import regex
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """ FULL_TEMPLATE = """
TITLE: {title} TITLE: {title}
...@@ -36,106 +27,6 @@ TEXT: ...@@ -36,106 +27,6 @@ TEXT:
""" """
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
summary: bool = Field(
default=False,
description="When the user's question requires extracting the summarizing content of the webpage, "
"set it to true."
)
cursor: int = Field(
default=0,
description="Start reading from this character."
"Use when the first response was truncated"
"and you want to continue reading the page."
"The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
page_contents: str = None
url: str = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
model_config: ModelConfigEntity
model_parameters: dict[str, Any]
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.PROMPT,
parameters=self.model_parameters
)
refine_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.REFINE_PROMPT,
parameters=self.model_parameters
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str: def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length] return text[cursor: cursor + max_length]
......
import re import re
import uuid import uuid
from core.agent.agent_executor import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
......
from flask import current_app
from flask_login import current_user from flask_login import current_user
from extensions.ext_database import db from extensions.ext_database import db
...@@ -31,7 +33,15 @@ class WorkspaceService: ...@@ -31,7 +33,15 @@ class WorkspaceService:
can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): if can_replace_logo and TenantService.has_roles(tenant,
tenant_info['custom_config'] = tenant.custom_config_dict [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
base_url = current_app.config.get('FILES_URL')
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
tenant_info['custom_config'] = {
'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo,
}
return tenant_info return tenant_info
...@@ -2,7 +2,7 @@ version: '3.1' ...@@ -2,7 +2,7 @@ version: '3.1'
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.5.6 image: langgenius/dify-api:0.5.7
restart: always restart: always
environment: environment:
# Startup mode, 'api' starts the API server. # Startup mode, 'api' starts the API server.
...@@ -135,7 +135,7 @@ services: ...@@ -135,7 +135,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.5.6 image: langgenius/dify-api:0.5.7
restart: always restart: always
environment: environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue. # Startup mode, 'worker' starts the Celery worker for processing the queue.
...@@ -206,7 +206,7 @@ services: ...@@ -206,7 +206,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.5.6 image: langgenius/dify-web:0.5.7
restart: always restart: always
environment: environment:
EDITION: SELF_HOSTED EDITION: SELF_HOSTED
......
...@@ -3,16 +3,13 @@ import React from 'react' ...@@ -3,16 +3,13 @@ import React from 'react'
import Spinner from '../spinner' import Spinner from '../spinner'
export type IButtonProps = { export type IButtonProps = {
/** type?: string
* The style of the button
*/
type?: 'primary' | 'warning' | (string & {})
className?: string className?: string
disabled?: boolean disabled?: boolean
loading?: boolean loading?: boolean
tabIndex?: number tabIndex?: number
children: React.ReactNode children: React.ReactNode
onClick?: MouseEventHandler<HTMLButtonElement> onClick?: MouseEventHandler<HTMLDivElement>
} }
const Button: FC<IButtonProps> = ({ const Button: FC<IButtonProps> = ({
...@@ -38,16 +35,15 @@ const Button: FC<IButtonProps> = ({ ...@@ -38,16 +35,15 @@ const Button: FC<IButtonProps> = ({
} }
return ( return (
<button <div
className={`btn ${style} ${className && className}`} className={`btn ${style} ${className && className}`}
tabIndex={tabIndex} tabIndex={tabIndex}
disabled={disabled} onClick={disabled ? undefined : onClick}
onClick={onClick}
> >
{children} {children}
{/* Spinner is hidden when loading is false */} {/* Spinner is hidden when loading is false */}
<Spinner loading={loading} className='!text-white !h-3 !w-3 !border-2 !ml-1' /> <Spinner loading={loading} className='!text-white !h-3 !w-3 !border-2 !ml-1' />
</button> </div>
) )
} }
......
...@@ -16,8 +16,6 @@ import { ...@@ -16,8 +16,6 @@ import {
updateCurrentWorkspace, updateCurrentWorkspace,
} from '@/service/common' } from '@/service/common'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { API_PREFIX } from '@/config'
import { getPurifyHref } from '@/utils'
const ALLOW_FILE_EXTENSIONS = ['svg', 'png'] const ALLOW_FILE_EXTENSIONS = ['svg', 'png']
...@@ -123,7 +121,7 @@ const CustomWebAppBrand = () => { ...@@ -123,7 +121,7 @@ const CustomWebAppBrand = () => {
POWERED BY POWERED BY
{ {
webappLogo webappLogo
? <img key={webappLogo} src={`${getPurifyHref(API_PREFIX.slice(0, -12))}/files/workspaces/${currentWorkspace.id}/webapp-logo`} alt='logo' className='ml-2 block w-auto h-5' /> ? <img key={webappLogo} src={webappLogo} alt='logo' className='ml-2 block w-auto h-5' />
: <LogoSite className='ml-2 !h-5' /> : <LogoSite className='ml-2 !h-5' />
} }
</div> </div>
......
...@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran ...@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran
</Col> </Col>
<Col sticky> <Col sticky>
### Request Example ### Request Example
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
......
...@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success - `result` (string) 固定返回 success
</Col> </Col>
<Col sticky> <Col sticky>
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
......
...@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to ...@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to
</Col> </Col>
<Col sticky> <Col sticky>
### Request Example ### Request Example
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{"user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{"user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
...@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to ...@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to
- (string) url of icon - (string) url of icon
</Col> </Col>
<Col> <Col>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}> <CodeGroup title="Request" tag="GET" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \ curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}' -H 'Authorization: Bearer {api_key}'
``` ```
......
...@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success - `result` (string) 固定返回 success
</Col> </Col>
<Col sticky> <Col sticky>
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
...@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- (string) 图标URL - (string) 图标URL
</Col> </Col>
<Col> <Col>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}> <CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \ curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}' -H 'Authorization: Bearer {api_key}'
``` ```
......
...@@ -26,13 +26,12 @@ const Apps: FC = () => { ...@@ -26,13 +26,12 @@ const Apps: FC = () => {
const { isCurrentWorkspaceManager } = useAppContext() const { isCurrentWorkspaceManager } = useAppContext()
const router = useRouter() const router = useRouter()
const { hasEditPermission } = useContext(ExploreContext) const { hasEditPermission } = useContext(ExploreContext)
const allCategoriesEn = t('explore.apps.allCategories', { lng: 'en' }) const allCategoriesEn = t('explore.apps.allCategories')
const [currCategory, setCurrCategory] = useTabSearchParams({ const [currCategory, setCurrCategory] = useTabSearchParams({
defaultTab: allCategoriesEn, defaultTab: allCategoriesEn,
}) })
const { const {
data: { categories, allList }, data: { categories, allList },
isLoading,
} = useSWR( } = useSWR(
['/explore/apps'], ['/explore/apps'],
() => () =>
...@@ -90,7 +89,7 @@ const Apps: FC = () => { ...@@ -90,7 +89,7 @@ const Apps: FC = () => {
} }
} }
if (!isLoading) { if (!categories) {
return ( return (
<div className="flex h-full items-center"> <div className="flex h-full items-center">
<Loading type="area" /> <Loading type="area" />
......
...@@ -3,11 +3,13 @@ import type { FC } from 'react' ...@@ -3,11 +3,13 @@ import type { FC } from 'react'
import React from 'react' import React from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import Tooltip from '../../base/tooltip'
import { HelpCircle } from '../../base/icons/src/vender/line/general'
import type { Credential } from '@/app/components/tools/types' import type { Credential } from '@/app/components/tools/types'
import Drawer from '@/app/components/base/drawer-plus' import Drawer from '@/app/components/base/drawer-plus'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Radio from '@/app/components/base/radio/ui' import Radio from '@/app/components/base/radio/ui'
import { AuthType } from '@/app/components/tools/types' import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'
type Props = { type Props = {
credential: Credential credential: Credential
...@@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium text-gray-900' ...@@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium text-gray-900'
type ItemProps = { type ItemProps = {
text: string text: string
value: AuthType value: AuthType | AuthHeaderPrefix
isChecked: boolean isChecked: boolean
onClick: (value: AuthType) => void onClick: (value: AuthType | AuthHeaderPrefix) => void
} }
const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => { const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => {
...@@ -31,7 +33,6 @@ const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => { ...@@ -31,7 +33,6 @@ const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => {
> >
<Radio isChecked={isChecked} /> <Radio isChecked={isChecked} />
<div className='text-sm font-normal text-gray-900'>{text}</div> <div className='text-sm font-normal text-gray-900'>{text}</div>
</div> </div>
) )
} }
...@@ -43,6 +44,7 @@ const ConfigCredential: FC<Props> = ({ ...@@ -43,6 +44,7 @@ const ConfigCredential: FC<Props> = ({
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const [tempCredential, setTempCredential] = React.useState<Credential>(credential) const [tempCredential, setTempCredential] = React.useState<Credential>(credential)
return ( return (
<Drawer <Drawer
isShow isShow
...@@ -62,20 +64,59 @@ const ConfigCredential: FC<Props> = ({ ...@@ -62,20 +64,59 @@ const ConfigCredential: FC<Props> = ({
text={t('tools.createTool.authMethod.types.none')} text={t('tools.createTool.authMethod.types.none')}
value={AuthType.none} value={AuthType.none}
isChecked={tempCredential.auth_type === AuthType.none} isChecked={tempCredential.auth_type === AuthType.none}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value })} onClick={value => setTempCredential({ ...tempCredential, auth_type: value as AuthType })}
/> />
<SelectItem <SelectItem
text={t('tools.createTool.authMethod.types.api_key')} text={t('tools.createTool.authMethod.types.api_key')}
value={AuthType.apiKey} value={AuthType.apiKey}
isChecked={tempCredential.auth_type === AuthType.apiKey} isChecked={tempCredential.auth_type === AuthType.apiKey}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value })} onClick={value => setTempCredential({
...tempCredential,
auth_type: value as AuthType,
api_key_header: tempCredential.api_key_header || 'Authorization',
api_key_value: tempCredential.api_key_value || '',
api_key_header_prefix: tempCredential.api_key_header_prefix || AuthHeaderPrefix.custom,
})}
/> />
</div> </div>
</div> </div>
{tempCredential.auth_type === AuthType.apiKey && ( {tempCredential.auth_type === AuthType.apiKey && (
<> <>
<div className={keyClassNames}>{t('tools.createTool.authHeaderPrefix.title')}</div>
<div className='flex space-x-3'>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.basic')}
value={AuthHeaderPrefix.basic}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.basic}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.bearer')}
value={AuthHeaderPrefix.bearer}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.bearer}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.custom')}
value={AuthHeaderPrefix.custom}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.custom}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
</div>
<div> <div>
<div className={keyClassNames}>{t('tools.createTool.authMethod.key')}</div> <div className='flex items-center h-8 text-[13px] font-medium text-gray-900'>
{t('tools.createTool.authMethod.key')}
<Tooltip
selector='model-page-system-reasoning-model-tip'
htmlContent={
<div className='w-[261px] text-gray-500'>
{t('tools.createTool.authMethod.keyTooltip')}
</div>
}
>
<HelpCircle className='ml-0.5 w-[14px] h-[14px] text-gray-400'/>
</Tooltip>
</div>
<input <input
value={tempCredential.api_key_header} value={tempCredential.api_key_header}
onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })} onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
...@@ -83,7 +124,6 @@ const ConfigCredential: FC<Props> = ({ ...@@ -83,7 +124,6 @@ const ConfigCredential: FC<Props> = ({
placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!} placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!}
/> />
</div> </div>
<div> <div>
<div className={keyClassNames}>{t('tools.createTool.authMethod.value')}</div> <div className={keyClassNames}>{t('tools.createTool.authMethod.value')}</div>
<input <input
......
...@@ -8,7 +8,7 @@ import { clone } from 'lodash-es' ...@@ -8,7 +8,7 @@ import { clone } from 'lodash-es'
import cn from 'classnames' import cn from 'classnames'
import { LinkExternal02, Settings01 } from '../../base/icons/src/vender/line/general' import { LinkExternal02, Settings01 } from '../../base/icons/src/vender/line/general'
import type { Credential, CustomCollectionBackend, CustomParamSchema, Emoji } from '../types' import type { Credential, CustomCollectionBackend, CustomParamSchema, Emoji } from '../types'
import { AuthType } from '../types' import { AuthHeaderPrefix, AuthType } from '../types'
import GetSchema from './get-schema' import GetSchema from './get-schema'
import ConfigCredentials from './config-credentials' import ConfigCredentials from './config-credentials'
import TestApi from './test-api' import TestApi from './test-api'
...@@ -37,6 +37,7 @@ const EditCustomCollectionModal: FC<Props> = ({ ...@@ -37,6 +37,7 @@ const EditCustomCollectionModal: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const isAdd = !payload const isAdd = !payload
const isEdit = !!payload const isEdit = !!payload
const [editFirst, setEditFirst] = useState(!isAdd) const [editFirst, setEditFirst] = useState(!isAdd)
const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || []) const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || [])
const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd
...@@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC<Props> = ({ ...@@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC<Props> = ({
provider: '', provider: '',
credentials: { credentials: {
auth_type: AuthType.none, auth_type: AuthType.none,
api_key_header: 'Authorization',
api_key_header_prefix: AuthHeaderPrefix.basic,
}, },
icon: { icon: {
content: '🕵️', content: '🕵️',
......
...@@ -3,7 +3,7 @@ import type { FC } from 'react' ...@@ -3,7 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useState } from 'react' import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import { CollectionType, LOC } from '../types' import { AuthHeaderPrefix, AuthType, CollectionType, LOC } from '../types'
import type { Collection, CustomCollectionBackend, Tool } from '../types' import type { Collection, CustomCollectionBackend, Tool } from '../types'
import Loading from '../../base/loading' import Loading from '../../base/loading'
import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
...@@ -53,6 +53,10 @@ const ToolList: FC<Props> = ({ ...@@ -53,6 +53,10 @@ const ToolList: FC<Props> = ({
(async () => { (async () => {
if (collection.type === CollectionType.custom) { if (collection.type === CollectionType.custom) {
const res = await fetchCustomCollection(collection.name) const res = await fetchCustomCollection(collection.name)
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
if (res.credentials.api_key_value)
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
}
setCustomCollection({ setCustomCollection({
...res, ...res,
provider: collection.name, provider: collection.name,
......
...@@ -9,10 +9,17 @@ export enum AuthType { ...@@ -9,10 +9,17 @@ export enum AuthType {
apiKey = 'api_key', apiKey = 'api_key',
} }
export enum AuthHeaderPrefix {
basic = 'basic',
bearer = 'bearer',
custom = 'custom',
}
export type Credential = { export type Credential = {
'auth_type': AuthType 'auth_type': AuthType
'api_key_header'?: string 'api_key_header'?: string
'api_key_value'?: string 'api_key_value'?: string
'api_key_header_prefix'?: AuthHeaderPrefix
} }
export enum CollectionType { export enum CollectionType {
......
...@@ -51,6 +51,7 @@ const translation = { ...@@ -51,6 +51,7 @@ const translation = {
authMethod: { authMethod: {
title: 'Authorization method', title: 'Authorization method',
type: 'Authorization type', type: 'Authorization type',
keyTooltip: 'Http Header Key, You can leave it with "Authorization" if you have no idea what it is or set it to a custom value',
types: { types: {
none: 'None', none: 'None',
api_key: 'API Key', api_key: 'API Key',
...@@ -60,6 +61,14 @@ const translation = { ...@@ -60,6 +61,14 @@ const translation = {
key: 'Key', key: 'Key',
value: 'Value', value: 'Value',
}, },
authHeaderPrefix: {
title: 'Auth Type',
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Privacy policy', privacyPolicy: 'Privacy policy',
privacyPolicyPlaceholder: 'Please enter privacy policy', privacyPolicyPlaceholder: 'Please enter privacy policy',
}, },
......
...@@ -58,6 +58,13 @@ const translation = { ...@@ -58,6 +58,13 @@ const translation = {
key: 'Chave', key: 'Chave',
value: 'Valor', value: 'Valor',
}, },
authHeaderPrefix: {
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Política de Privacidade', privacyPolicy: 'Política de Privacidade',
privacyPolicyPlaceholder: 'Digite a política de privacidade', privacyPolicyPlaceholder: 'Digite a política de privacidade',
}, },
......
...@@ -58,6 +58,13 @@ const translation = { ...@@ -58,6 +58,13 @@ const translation = {
key: 'Ключ', key: 'Ключ',
value: 'Значення', value: 'Значення',
}, },
authHeaderPrefix: {
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Політика конфіденційності', privacyPolicy: 'Політика конфіденційності',
privacyPolicyPlaceholder: 'Введіть політику конфіденційності', privacyPolicyPlaceholder: 'Введіть політику конфіденційності',
}, },
......
...@@ -51,6 +51,7 @@ const translation = { ...@@ -51,6 +51,7 @@ const translation = {
authMethod: { authMethod: {
title: '鉴权方法', title: '鉴权方法',
type: '鉴权类型', type: '鉴权类型',
keyTooltip: 'HTTP 头部名称,如果你不知道是什么,可以将其保留为 Authorization 或设置为自定义值',
types: { types: {
none: '无', none: '无',
api_key: 'API Key', api_key: 'API Key',
...@@ -60,6 +61,14 @@ const translation = { ...@@ -60,6 +61,14 @@ const translation = {
key: '键', key: '键',
value: '值', value: '值',
}, },
authHeaderPrefix: {
title: '鉴权头部前缀',
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: '隐私协议', privacyPolicy: '隐私协议',
privacyPolicyPlaceholder: '请输入隐私协议', privacyPolicyPlaceholder: '请输入隐私协议',
}, },
......
{ {
"name": "dify-web", "name": "dify-web",
"version": "0.5.6", "version": "0.5.7",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "next dev", "dev": "next dev",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment