Commit 5ff6284e authored by StyleZhang's avatar StyleZhang

Merge branch 'fix/model-parameter-load-preset-config' into deploy/dev

parents 55c161c7 13116fe3
......@@ -90,7 +90,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.6"
self.CURRENT_VERSION = "0.5.7"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
......
from extensions.ext_database import db
from models.model import EndUser
def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
import json
from flask import current_app
from flask_restful import fields, marshal_with
from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
class AppParameterApi(AppApiResource):
class AppParameterApi(Resource):
"""Resource for app variables."""
variable_fields = {
......@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields)
}
@validate_app_token
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
def get(self, app_model: App):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
......@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
}
}
class AppMetaApi(AppApiResource):
def get(self, app_model: App, end_user):
class AppMetaApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config
......
import logging
from flask import request
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
......@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
......@@ -30,8 +30,9 @@ from services.errors.audio import (
)
class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
......@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError()
class TextApi(AppApiResource):
def post(self, app_model: App, end_user):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
......@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=args['user'],
end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)
......
......@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
......@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService
class CompletionApi(AppApiResource):
def post(self, app_model, end_user):
class CompletionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion':
raise AppUnavailableError()
......@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False
try:
......@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError()
class CompletionStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
class ChatApi(AppApiResource):
def post(self, app_model, end_user):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
......@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
......@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
response = CompletionService.completion(
app_model=app_model,
......@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError()
class ChatStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
......
from flask import request
from flask_restful import marshal_with, reqparse
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService
class ConversationApi(AppApiResource):
class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id):
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
......@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
......@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.rename(
app_model,
......
from flask import request
from flask_restful import marshal_with
from flask_restful import Resource, marshal_with
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService
class FileApi(AppApiResource):
class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model, end_user):
def post(self, app_model: App, end_user: EndUser):
file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file
if 'file' not in request.files:
......
from flask_restful import fields, marshal_with, reqparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from extensions.ext_database import db
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message
from models.model import App, EndUser
from services.message_service import MessageService
class MessageListApi(AppApiResource):
class MessageListApi(Resource):
feedback_fields = {
'rating': fields.String
}
......@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields))
}
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
......@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
......@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(AppApiResource):
def post(self, app_model, end_user, message_id):
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
......@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
user=end_user,
message_id=message_id,
check_enabled=False
)
......
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
def validate_app_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = 'query'
JSON = 'json'
FORM = 'form'
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
......@@ -29,15 +47,34 @@ def validate_app_token(view=None):
if not app_model.enable_api:
raise NotFound()
return view(app_model, None, *args, **kwargs)
return decorated
kwargs['app_model'] = app_model
if view:
return decorator(view)
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
if user_id:
user_id = str(user_id)
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
def cloud_edition_billing_resource_check(resource: str,
......@@ -131,8 +168,33 @@ def validate_and_get_api_token(scope=None):
return api_token
class AppApiResource(Resource):
method_decorators = [validate_app_token]
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource):
......
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
:param model_config:
:param messages:
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass
This diff is collapsed.
This diff is collapsed.
import json
import logging
from typing import cast
......@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
......@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
message_chain = self._init_message_chain(
message=message,
query=query
)
# init model instance
model_instance = ModelInstance(
......@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables
})
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""
......
......@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
......
......@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
'id': self._message.id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'answer': self._task_state.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}
......
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool
\ No newline at end of file
......@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
......
......@@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
......
......@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
......
import enum
import logging
from typing import Optional, Union
......@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
......@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
......@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
......
......@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
......
......@@ -4,7 +4,6 @@ from typing import Any, Optional
import requests
from flask import current_app
from flask_login import current_user
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
......@@ -43,7 +42,7 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
self._notion_access_token = self._get_access_token(tenant_id,
self._notion_workspace_id)
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
......
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket
class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'
model_api_configs = {
'spark': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-v2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-v3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-v3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
}
}
api_version = model_api_configs[model_name]['version']
self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
urlparse(self.api_base).path,
self.api_base,
api_key,
api_secret
)
self.queue = queue.Queue()
self.blocking_message = ''
def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
# generate timestamp by RFC1123
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"
# encrypt using hmac-sha256
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
}
# generate url
url = api_base + '?' + urlencode(v)
return url
def run(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None, streaming: bool = False):
websocket.enableTrace(False)
ws = websocket.WebSocketApp(
self.ws_url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
on_open=self.on_open
)
ws.messages = messages
ws.user_id = user_id
ws.model_kwargs = model_kwargs
ws.streaming = streaming
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error):
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close()
def on_close(self, ws, close_status_code, close_reason):
self.queue.put({'done': True})
def on_open(self, ws):
self.blocking_message = ''
data = json.dumps(self.gen_params(
messages=ws.messages,
user_id=ws.user_id,
model_kwargs=ws.model_kwargs
))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
self.queue.put({
'status_code': 400,
'error': f"Code: {code}, Error: {data['header']['message']}"
})
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
if ws.streaming:
self.queue.put({'data': content})
else:
self.blocking_message += content
if status == 2:
if not ws.streaming:
self.queue.put({'data': self.blocking_message})
ws.close()
def gen_params(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None) -> dict:
data = {
"header": {
"app_id": self.app_id,
"uid": user_id
},
"parameter": {
"chat": {
"domain": self.chat_domain
}
},
"payload": {
"message": {
"text": messages
}
}
}
if model_kwargs:
data['parameter']['chat'].update(model_kwargs)
return data
def subscribe(self):
while True:
content = self.queue.get()
if 'error' in content:
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
if 'data' not in content:
break
yield content
class SparkError(Exception):
pass
from datetime import datetime
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class DatetimeToolInput(BaseModel):
type: str = Field(..., description="Type for current time, must be: datetime.")
class DatetimeTool(BaseTool):
"""Tool for querying current datetime."""
name: str = "current_datetime"
args_schema: type[BaseModel] = DatetimeToolInput
description: str = "A tool when you want to get the current date, time, week, month or year, " \
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
def _run(self, type: str) -> str:
# get current time
current_time = datetime.utcnow()
return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.tool_name == self.get_provider_name().value
)
if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)
return query.first()
def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
if obfuscated:
return self._obfuscated_token(token)
return token
def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
class ToolValidateFailedError(Exception):
description = "Tool Provider Validate failed"
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName
class SerpAPIToolProvider(BaseToolProvider):
def get_provider_name(self) -> ToolProviderName:
"""
Returns the name of the provider.
:return:
"""
return ToolProviderName.SERPAPI
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for SerpAPI as a dictionary.
:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider = self.get_provider(must_enabled=True)
if not tool_provider:
return None
credentials = tool_provider.credentials
if not credentials:
return None
if credentials.get('api_key'):
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
return credentials
def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
Returns the credentials function kwargs as a dictionary.
:return:
"""
credentials = self.get_credentials()
if not credentials:
return None
return {
'serpapi_api_key': credentials.get('api_key')
}
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolValidateFailedError("SerpAPI api_key is required.")
api_key = credentials.get('api_key')
try:
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
except Exception as e:
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
return credentials
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
class ToolProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self._init_provider(tenant_id, provider_name)
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
if provider_name == 'serpapi':
return SerpAPIToolProvider(tenant_id)
else:
raise Exception('tool provider {} not found'.format(provider_name))
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for Tool as a dictionary.
:param obfuscated:
:return:
"""
return self.provider.get_credentials(obfuscated)
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)
def encrypt_credentials(self, credentials: dict):
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
return self.provider.encrypt_credentials(credentials)
from langchain import SerpAPIWrapper
from pydantic import BaseModel, Field
class OptimizedSerpAPIInput(BaseModel):
query: str = Field(..., description="search query.")
class OptimizedSerpAPIWrapper(SerpAPIWrapper):
@staticmethod
def _process_response(res: dict, num_results: int = 5) -> str:
"""Process response from SerpAPI."""
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
res["answer_box"] = res["answer_box"][0]
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
toret = res["sports_results"]["game_spotlight"]
elif (
"shopping_results" in res.keys()
and "title" in res["shopping_results"][0].keys()
):
toret = res["shopping_results"][:3]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
toret = res["knowledge_graph"]["description"]
elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
toret = ""
for result in res["organic_results"][:num_results]:
if "link" in result:
toret += "----------------\nlink: " + result["link"] + "\n"
if "snippet" in result:
toret += "snippet: " + result["snippet"] + "\n"
else:
toret = "No good search result found"
return "search result:\n" + toret
This diff is collapsed.
......@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController):
en_US='The api key',
zh_Hans='api key的值'
)
),
'api_key_header_prefix': ToolProviderCredentials(
name='api_key_header_prefix',
required=False,
default='basic',
type=ToolProviderCredentials.CredentialsType.SELECT,
help=I18nObject(
en_US='The prefix of the api key header',
zh_Hans='api key header 的前缀'
),
options=[
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
]
)
}
elif auth_type == ApiProviderAuthType.NONE:
......
......@@ -62,6 +62,17 @@ class ApiTool(Tool):
if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value')
elif not isinstance(credentials['api_key_value'], str):
raise ToolProviderCredentialValidationError('api_key_value must be a string')
if 'api_key_header_prefix' in credentials:
api_key_header_prefix = credentials['api_key_header_prefix']
if api_key_header_prefix == 'basic':
credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
elif api_key_header_prefix == 'bearer':
credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == 'custom':
pass
headers[api_key_header] = credentials['api_key_value']
......
......@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
......@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
@staticmethod
def get_dataset_tools(tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']:
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']:
"""
get dataset tool
"""
......@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
)
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools
tools = []
for langchain_tool in langchain_tools:
......@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
llm=langchain_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
tools.append(tool)
return tools
......@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
def get_runtime_parameters(self) -> list[ToolParameter]:
return [
ToolParameter(name='query',
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
]
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
......@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
query = tool_parameters.get('query', None)
if not query:
return self.create_text_message(text='please input query')
# invoke dataset retriever tool
result = self.langchain_tool._run(query=query)
......@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
"""
validate the credentials for dataset retriever tool
"""
pass
\ No newline at end of file
pass
......@@ -7,23 +7,14 @@ import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any
import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
......@@ -36,106 +27,6 @@ TEXT:
"""
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
summary: bool = Field(
default=False,
description="When the user's question requires extracting the summarizing content of the webpage, "
"set it to true."
)
cursor: int = Field(
default=0,
description="Start reading from this character."
"Use when the first response was truncated"
"and you want to continue reading the page."
"The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
page_contents: str = None
url: str = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
model_config: ModelConfigEntity
model_parameters: dict[str, Any]
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.PROMPT,
parameters=self.model_parameters
)
refine_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.REFINE_PROMPT,
parameters=self.model_parameters
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length]
......
import re
import uuid
from core.agent.agent_executor import PlanningStrategy
from core.entities.agent_entities import PlanningStrategy
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory
......
......@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.5.6
image: langgenius/dify-api:0.5.7
restart: always
environment:
# Startup mode, 'api' starts the API server.
......@@ -135,7 +135,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.5.6
image: langgenius/dify-api:0.5.7
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
......@@ -206,7 +206,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.5.6
image: langgenius/dify-web:0.5.7
restart: always
environment:
EDITION: SELF_HOSTED
......
......@@ -3,16 +3,13 @@ import React from 'react'
import Spinner from '../spinner'
export type IButtonProps = {
/**
* The style of the button
*/
type?: 'primary' | 'warning' | (string & {})
type?: string
className?: string
disabled?: boolean
loading?: boolean
tabIndex?: number
children: React.ReactNode
onClick?: MouseEventHandler<HTMLButtonElement>
onClick?: MouseEventHandler<HTMLDivElement>
}
const Button: FC<IButtonProps> = ({
......@@ -38,16 +35,15 @@ const Button: FC<IButtonProps> = ({
}
return (
<button
<div
className={`btn ${style} ${className && className}`}
tabIndex={tabIndex}
disabled={disabled}
onClick={onClick}
onClick={disabled ? undefined : onClick}
>
{children}
{/* Spinner is hidden when loading is false */}
<Spinner loading={loading} className='!text-white !h-3 !w-3 !border-2 !ml-1' />
</button>
</div>
)
}
......
......@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran
</Col>
<Col sticky>
### Request Example
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \
curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
......
......@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/v1/completion-messages/:task_id/stop' \
curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
......
......@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to
</Col>
<Col sticky>
### Request Example
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{"user": "abc-123"}'`}>
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{"user": "abc-123"}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \
curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
......@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to
- (string) url of icon
</Col>
<Col>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
<CodeGroup title="Request" tag="GET" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
```bash {{ title: 'cURL' }}
curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \
curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}'
```
......
......@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/v1/chat-messages/:task_id/stop' \
curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \
--data-raw '{
......@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- (string) 图标URL
</Col>
<Col>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
```bash {{ title: 'cURL' }}
curl -X GET '${props.appDetail.api_base_url}/v1/meta?user=abc-123' \
curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}'
```
......
......@@ -115,6 +115,11 @@ const ParameterItem: FC<ParameterItemProps> = ({
}
}
useEffect(() => {
if ((parameterRule.type === 'int' || parameterRule.type === 'float') && numberInputRef.current)
numberInputRef.current.value = `${renderValue}`
}, [value])
const renderInput = () => {
const numberInputWithSlide = (parameterRule.type === 'int' || parameterRule.type === 'float')
&& !isNullOrUndefined(parameterRule.min)
......@@ -207,11 +212,6 @@ const ParameterItem: FC<ParameterItemProps> = ({
return null
}
useEffect(() => {
if (numberInputRef.current)
numberInputRef.current.value = `${renderValue}`
}, [])
return (
<div className={`flex items-center justify-between ${className}`}>
<div>
......
......@@ -3,11 +3,13 @@ import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import Tooltip from '../../base/tooltip'
import { HelpCircle } from '../../base/icons/src/vender/line/general'
import type { Credential } from '@/app/components/tools/types'
import Drawer from '@/app/components/base/drawer-plus'
import Button from '@/app/components/base/button'
import Radio from '@/app/components/base/radio/ui'
import { AuthType } from '@/app/components/tools/types'
import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'
type Props = {
credential: Credential
......@@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium text-gray-900'
type ItemProps = {
text: string
value: AuthType
value: AuthType | AuthHeaderPrefix
isChecked: boolean
onClick: (value: AuthType) => void
onClick: (value: AuthType | AuthHeaderPrefix) => void
}
const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => {
......@@ -31,7 +33,6 @@ const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => {
>
<Radio isChecked={isChecked} />
<div className='text-sm font-normal text-gray-900'>{text}</div>
</div>
)
}
......@@ -43,6 +44,7 @@ const ConfigCredential: FC<Props> = ({
}) => {
const { t } = useTranslation()
const [tempCredential, setTempCredential] = React.useState<Credential>(credential)
return (
<Drawer
isShow
......@@ -62,20 +64,59 @@ const ConfigCredential: FC<Props> = ({
text={t('tools.createTool.authMethod.types.none')}
value={AuthType.none}
isChecked={tempCredential.auth_type === AuthType.none}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value })}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value as AuthType })}
/>
<SelectItem
text={t('tools.createTool.authMethod.types.api_key')}
value={AuthType.apiKey}
isChecked={tempCredential.auth_type === AuthType.apiKey}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value })}
onClick={value => setTempCredential({
...tempCredential,
auth_type: value as AuthType,
api_key_header: tempCredential.api_key_header || 'Authorization',
api_key_value: tempCredential.api_key_value || '',
api_key_header_prefix: tempCredential.api_key_header_prefix || AuthHeaderPrefix.custom,
})}
/>
</div>
</div>
{tempCredential.auth_type === AuthType.apiKey && (
<>
<div className={keyClassNames}>{t('tools.createTool.authHeaderPrefix.title')}</div>
<div className='flex space-x-3'>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.basic')}
value={AuthHeaderPrefix.basic}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.basic}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.bearer')}
value={AuthHeaderPrefix.bearer}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.bearer}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.custom')}
value={AuthHeaderPrefix.custom}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.custom}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
</div>
<div>
<div className={keyClassNames}>{t('tools.createTool.authMethod.key')}</div>
<div className='flex items-center h-8 text-[13px] font-medium text-gray-900'>
{t('tools.createTool.authMethod.key')}
<Tooltip
selector='model-page-system-reasoning-model-tip'
htmlContent={
<div className='w-[261px] text-gray-500'>
{t('tools.createTool.authMethod.keyTooltip')}
</div>
}
>
<HelpCircle className='ml-0.5 w-[14px] h-[14px] text-gray-400'/>
</Tooltip>
</div>
<input
value={tempCredential.api_key_header}
onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
......@@ -83,7 +124,6 @@ const ConfigCredential: FC<Props> = ({
placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!}
/>
</div>
<div>
<div className={keyClassNames}>{t('tools.createTool.authMethod.value')}</div>
<input
......
......@@ -8,7 +8,7 @@ import { clone } from 'lodash-es'
import cn from 'classnames'
import { LinkExternal02, Settings01 } from '../../base/icons/src/vender/line/general'
import type { Credential, CustomCollectionBackend, CustomParamSchema, Emoji } from '../types'
import { AuthType } from '../types'
import { AuthHeaderPrefix, AuthType } from '../types'
import GetSchema from './get-schema'
import ConfigCredentials from './config-credentials'
import TestApi from './test-api'
......@@ -37,6 +37,7 @@ const EditCustomCollectionModal: FC<Props> = ({
const { t } = useTranslation()
const isAdd = !payload
const isEdit = !!payload
const [editFirst, setEditFirst] = useState(!isAdd)
const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || [])
const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd
......@@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC<Props> = ({
provider: '',
credentials: {
auth_type: AuthType.none,
api_key_header: 'Authorization',
api_key_header_prefix: AuthHeaderPrefix.basic,
},
icon: {
content: '🕵️',
......
......@@ -3,7 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import { CollectionType, LOC } from '../types'
import { AuthHeaderPrefix, AuthType, CollectionType, LOC } from '../types'
import type { Collection, CustomCollectionBackend, Tool } from '../types'
import Loading from '../../base/loading'
import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
......@@ -53,6 +53,10 @@ const ToolList: FC<Props> = ({
(async () => {
if (collection.type === CollectionType.custom) {
const res = await fetchCustomCollection(collection.name)
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
if (res.credentials.api_key_value)
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
}
setCustomCollection({
...res,
provider: collection.name,
......
......@@ -9,10 +9,17 @@ export enum AuthType {
apiKey = 'api_key',
}
export enum AuthHeaderPrefix {
basic = 'basic',
bearer = 'bearer',
custom = 'custom',
}
export type Credential = {
'auth_type': AuthType
'api_key_header'?: string
'api_key_value'?: string
'api_key_header_prefix'?: AuthHeaderPrefix
}
export enum CollectionType {
......
......@@ -51,6 +51,7 @@ const translation = {
authMethod: {
title: 'Authorization method',
type: 'Authorization type',
keyTooltip: 'Http Header Key, You can leave it with "Authorization" if you have no idea what it is or set it to a custom value',
types: {
none: 'None',
api_key: 'API Key',
......@@ -60,6 +61,14 @@ const translation = {
key: 'Key',
value: 'Value',
},
authHeaderPrefix: {
title: 'Auth Type',
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Privacy policy',
privacyPolicyPlaceholder: 'Please enter privacy policy',
},
......
......@@ -58,6 +58,13 @@ const translation = {
key: 'Chave',
value: 'Valor',
},
authHeaderPrefix: {
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Política de Privacidade',
privacyPolicyPlaceholder: 'Digite a política de privacidade',
},
......
......@@ -58,6 +58,13 @@ const translation = {
key: 'Ключ',
value: 'Значення',
},
authHeaderPrefix: {
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Політика конфіденційності',
privacyPolicyPlaceholder: 'Введіть політику конфіденційності',
},
......
......@@ -51,6 +51,7 @@ const translation = {
authMethod: {
title: '鉴权方法',
type: '鉴权类型',
keyTooltip: 'HTTP 头部名称,如果你不知道是什么,可以将其保留为 Authorization 或设置为自定义值',
types: {
none: '无',
api_key: 'API Key',
......@@ -60,6 +61,14 @@ const translation = {
key: '键',
value: '值',
},
authHeaderPrefix: {
title: '鉴权头部前缀',
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: '隐私协议',
privacyPolicyPlaceholder: '请输入隐私协议',
},
......
{
"name": "dify-web",
"version": "0.5.6",
"version": "0.5.7",
"private": true,
"scripts": {
"dev": "next dev",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment