Commit f7a90f26 authored by Joel

merge main

parents 3d825dcb dd961985
@@ -21,6 +21,17 @@
     <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
 </p>
+<p align="center">
+  <a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
+    Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
+  </a>
+  <ul align="center" style="text-decoration: none; list-style: none;">
+    <li> US EST: 09:00 (9:00 AM)</li>
+    <li> CET: 15:00 (3:00 PM)</li>
+    <li> CST: 22:00 (10:00 PM)</li>
+  </ul>
+</p>
 <p align="center">
   <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
     Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
...
@@ -6,15 +6,15 @@ import click
 from flask import current_app
 from werkzeug.exceptions import NotFound

-from core.embedding.cached_embedding import CacheEmbedding
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from models.model import Account
 from models.provider import Provider, ProviderModel
@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
         'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))


-@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
-def create_qdrant_indexes():
+@click.command('vdb-migrate', help='Migrate vector db.')
+def vdb_migrate():
     """
-    Migrate other vector database datas to Qdrant.
+    Migrate vector database data to the target vector database.
     """
-    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    click.echo(click.style('Start migrating vector db.', fg='green'))
     create_count = 0
+    config = current_app.config
+    vector_type = config.get('VECTOR_STORE')
     page = 1
     while True:
         try:
@@ -140,54 +141,101 @@ def create_qdrant_indexes():
         except NotFound:
             break

-        model_manager = ModelManager()
         page += 1
         for dataset in datasets:
-            if dataset.index_struct_dict:
-                if dataset.index_struct_dict['type'] != 'qdrant':
-                    try:
-                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
-                        try:
-                            embedding_model = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model
-                            )
-                        except Exception:
-                            continue
-                        embeddings = CacheEmbedding(embedding_model)
-
-                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-                        index = QdrantVectorIndex(
-                            dataset=dataset,
-                            config=QdrantConfig(
-                                endpoint=current_app.config.get('QDRANT_URL'),
-                                api_key=current_app.config.get('QDRANT_API_KEY'),
-                                root_path=current_app.root_path
-                            ),
-                            embeddings=embeddings
-                        )
-                        if index:
-                            index.create_qdrant_dataset(dataset)
-                            index_struct = {
-                                "type": 'qdrant',
-                                "vector_store": {
-                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                            }
-                            dataset.index_struct = json.dumps(index_struct)
-                            db.session.commit()
-                            create_count += 1
-                        else:
-                            click.echo('passed.')
-                    except Exception as e:
-                        click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                        fg='red'))
-                        continue
+            try:
+                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                if dataset.index_struct_dict:
+                    if dataset.index_struct_dict['type'] == vector_type:
+                        continue
+                if vector_type == "weaviate":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'weaviate',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "qdrant":
+                    if dataset.collection_binding_id:
+                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                            one_or_none()
+                        if dataset_collection_binding:
+                            collection_name = dataset_collection_binding.collection_name
+                        else:
+                            raise ValueError('Dataset collection binding does not exist!')
+                    else:
+                        dataset_id = dataset.id
+                        collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "milvus":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'milvus',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                else:
+                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+                vector = Vector(dataset)
+                click.echo(f"vdb_migrate {dataset.id}")
+
+                try:
+                    vector.delete()
+                except Exception as e:
+                    raise e
+
+                dataset_documents = db.session.query(DatasetDocument).filter(
+                    DatasetDocument.dataset_id == dataset.id,
+                    DatasetDocument.indexing_status == 'completed',
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                ).all()
+
+                documents = []
+                for dataset_document in dataset_documents:
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.status == 'completed',
+                        DocumentSegment.enabled == True
+                    ).all()

+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+                        documents.append(document)
+
+                if documents:
+                    try:
+                        vector.create(documents)
+                    except Exception as e:
+                        raise e
+                click.echo(f"Dataset {dataset.id} created successfully.")
+                db.session.add(dataset)
+                db.session.commit()
+                create_count += 1
+            except Exception as e:
+                db.session.rollback()
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue

     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -196,4 +244,4 @@ def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
-    app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(vdb_migrate)
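
Once registered, the renamed command can be invoked with `flask vdb-migrate`. The heart of the migration is the mapping from completed DocumentSegment rows to RAG Document payloads that are re-indexed through the Vector factory. A minimal, self-contained sketch of that mapping (using an illustrative Segment stand-in, not the real SQLAlchemy model):

# Sketch of the segment -> Document payload mapping used in vdb_migrate.
# `Segment` is an illustrative stand-in, not the real DocumentSegment model.
from dataclasses import dataclass


@dataclass
class Segment:
    content: str
    index_node_id: str
    index_node_hash: str
    document_id: str
    dataset_id: str


def to_document_payload(segment: Segment) -> dict:
    # Metadata keys mirror the ones written in the migration loop above.
    return {
        "page_content": segment.content,
        "metadata": {
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    }

Rebuilding the index from segments, rather than copying raw vectors, keeps the migration independent of how the source store laid out its data.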
@@ -38,7 +38,9 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
+    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -88,7 +90,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.5.6"
+        self.CURRENT_VERSION = "0.5.7"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -261,8 +263,10 @@ class Config:
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
         self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
+        self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
         self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
...
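
Both new settings are plain comma-separated strings. A short sketch of reading and splitting such a value, assuming an os.environ-backed get_env helper that falls back to the DEFAULTS table, as config.py does:

# Sketch: reading a comma-separated model list (helper and fallback assumed).
import os

DEFAULTS = {
    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct',
}


def get_env(key: str) -> str:
    return os.environ.get(key, DEFAULTS.get(key, ''))


trial_models = [m for m in get_env('HOSTED_OPENAI_TRIAL_MODELS').split(',') if m]
print(trial_models)  # ['gpt-3.5-turbo', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-instruct']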
@@ -124,19 +124,13 @@ class AppListApi(Resource):
             available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
             provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
             if provider_model not in available_models_names:
-                model_manager = ModelManager()
-                model_instance = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id,
-                    model_type=ModelType.LLM
-                )
-
-                if not model_instance:
+                if not default_model_entity:
                     raise ProviderNotInitializeError(
                         "No Default System Reasoning Model available. Please configure "
                         "in the Settings -> Model Provider.")
                 else:
-                    model_config_dict["model"]["provider"] = model_instance.provider
-                    model_config_dict["model"]["name"] = model_instance.model
+                    model_config_dict["model"]["provider"] = default_model_entity.provider
+                    model_config_dict["model"]["name"] = default_model_entity.model

             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
...
@@ -178,7 +178,8 @@ class DataSourceNotionApi(Resource):
                 notion_workspace_id=workspace_id,
                 notion_obj_id=page_id,
                 notion_page_type=page_type,
-                notion_access_token=data_source_binding.access_token
+                notion_access_token=data_source_binding.access_token,
+                tenant_id=current_user.current_tenant_id
             )
             text_docs = extractor.extract()
@@ -208,7 +209,8 @@ class DataSourceNotionApi(Resource):
                 notion_info={
                     "notion_workspace_id": workspace_id,
                     "notion_obj_id": page['page_id'],
-                    "notion_page_type": page['type']
+                    "notion_page_type": page['type'],
+                    "tenant_id": current_user.current_tenant_id
                 },
                 document_model=args['doc_form']
             )
...
@@ -298,7 +298,8 @@ class DatasetIndexingEstimateApi(Resource):
                 notion_info={
                     "notion_workspace_id": workspace_id,
                     "notion_obj_id": page['page_id'],
-                    "notion_page_type": page['type']
+                    "notion_page_type": page['type'],
+                    "tenant_id": current_user.current_tenant_id
                 },
                 document_model=args['doc_form']
             )
...
@@ -455,7 +455,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 notion_info={
                     "notion_workspace_id": data_source_info['notion_workspace_id'],
                     "notion_obj_id": data_source_info['notion_page_id'],
-                    "notion_page_type": data_source_info['type']
+                    "notion_page_type": data_source_info['type'],
+                    "tenant_id": current_user.current_tenant_id
                 },
                 document_model=document.doc_form
             )
...
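
All four hunks make the same change: the Notion extraction path now carries the caller's tenant_id alongside the page coordinates. A hypothetical helper (the function name is illustrative, not part of the diff) showing the resulting notion_info shape:

def build_notion_info(workspace_id: str, page: dict, tenant_id: str) -> dict:
    # Field names match the hunks above; tenant_id is the new required entry.
    return {
        "notion_workspace_id": workspace_id,
        "notion_obj_id": page['page_id'],
        "notion_page_type": page['type'],
        "tenant_id": tenant_id,
    }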
-from extensions.ext_database import db
-from models.model import EndUser
-
-
-def create_or_update_end_user_for_user_id(app_model, user_id):
-    """
-    Create or update session terminal based on user ID.
-    """
-    end_user = db.session.query(EndUser) \
-        .filter(
-            EndUser.tenant_id == app_model.tenant_id,
-            EndUser.session_id == user_id,
-            EndUser.type == 'service_api'
-        ).first()
-
-    if end_user is None:
-        end_user = EndUser(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            type='service_api',
-            is_anonymous=True,
-            session_id=user_id
-        )
-        db.session.add(end_user)
-        db.session.commit()
-
-    return end_user
 import json

 from flask import current_app
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with, Resource

 from controllers.service_api import api
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
 from models.tools import ApiToolProvider


-class AppParameterApi(AppApiResource):
+class AppParameterApi(Resource):
     """Resource for app variables."""

     variable_fields = {
@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
         'system_parameters': fields.Nested(system_parameters_fields)
     }

+    @validate_app_token
     @marshal_with(parameters_fields)
-    def get(self, app_model: App, end_user):
+    def get(self, app_model: App):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
         }
     }

-class AppMetaApi(AppApiResource):
-    def get(self, app_model: App, end_user):
+class AppMetaApi(Resource):
+    @validate_app_token
+    def get(self, app_model: App):
         """Get app meta"""
         app_model_config: AppModelConfig = app_model.app_model_config
...
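
With AppParameterApi now a plain Resource guarded by @validate_app_token, a caller needs only the app API token; the end_user argument is gone from the view signature. A hedged client sketch (base URL, token, and the /v1/parameters path are assumptions, not taken from this diff):

# Hedged example: fetching app parameters with an app API token.
import requests

resp = requests.get(
    "http://localhost:5001/v1/parameters",
    headers={"Authorization": "Bearer app-your-token"},
    timeout=10,
)
print(resp.status_code, resp.json())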
 import logging

 from flask import request
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError

 import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
     ProviderQuotaExceededError,
     UnsupportedAudioTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
 )


-class AudioApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class AudioApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
+    def post(self, app_model: App, end_user: EndUser):
         app_model_config: AppModelConfig = app_model.app_model_config

         if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
             raise InternalServerError()


-class TextApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class TextApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
         args = parser.parse_args()
@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=args['text'],
-            end_user=args['user'],
+            end_user=end_user,
             voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )
...
@@ -4,12 +4,11 @@ from collections.abc import Generator
 from typing import Union

 from flask import Response, stream_with_context
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     AppUnavailableError,
     CompletionRequestError,
@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.application_queue_manager import ApplicationQueueManager
 from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.completion_service import CompletionService


-class CompletionApi(AppApiResource):
-    def post(self, app_model, end_user):
+class CompletionApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('user', required=True, nullable=False, type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')

         args = parser.parse_args()

         streaming = args['response_mode'] == 'streaming'

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         args['auto_generate_name'] = False

         try:
@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
             raise InternalServerError()


-class CompletionStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class CompletionStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'completion':
             raise AppUnavailableError()

-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

         return {'result': 'success'}, 200


-class ChatApi(AppApiResource):
-    def post(self, app_model, end_user):
+class ChatApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
         streaming = args['response_mode'] == 'streaming'

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             response = CompletionService.completion(
                 app_model=app_model,
@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
             raise InternalServerError()


-class ChatStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class ChatStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

         return {'result': 'success'}, 200
...
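
The user field still travels in the JSON body, but it is now resolved into an EndUser by the decorator (FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) rather than inside each view. A hedged client sketch (endpoint path and payload keys are assumptions based on the Dify service API):

# Hedged example: the decorator, not the view, consumes the 'user' key.
import requests

resp = requests.post(
    "http://localhost:5001/v1/chat-messages",
    headers={"Authorization": "Bearer app-your-token"},
    json={
        "inputs": {},
        "query": "Hello",
        "response_mode": "blocking",
        "user": "end-user-123",
    },
    timeout=60,
)
print(resp.json())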
-from flask import request
-from flask_restful import marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.conversation_service import ConversationService


-class ConversationApi(AppApiResource):
+class ConversationApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(conversation_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()

         parser = reqparse.RequestParser()
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")


-class ConversationDetailApi(AppApiResource):
+class ConversationDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def delete(self, app_model, end_user, c_id):
+    def delete(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

         conversation_id = str(c_id)

-        user = request.get_json().get('user')
-        if end_user is None and user is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user)
-
         try:
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
         return {"result": "success"}, 204


-class ConversationRenameApi(AppApiResource):
+class ConversationRenameApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def post(self, app_model, end_user, c_id):
+    def post(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('user', type=str, location='json')
         parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.rename(
                 app_model,
...
 from flask import request
-from flask_restful import marshal_with
+from flask_restful import Resource, marshal_with

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     FileTooLargeError,
     NoFileUploadedError,
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.file_fields import file_fields
+from models.model import App, EndUser
 from services.file_service import FileService


-class FileApi(AppApiResource):
+class FileApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
     @marshal_with(file_fields)
-    def post(self, app_model, end_user):
+    def post(self, app_model: App, end_user: EndUser):
         file = request.files['file']

-        user_args = request.form.get('user')
-        if end_user is None and user_args is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
-
         # check file
         if 'file' not in request.files:
...
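
For uploads the decorator reads user from the multipart form instead (WhereisUserArg.FORM). A hedged client sketch (base URL and the /v1/files/upload path are assumptions):

# Hedged example: multipart upload; 'user' rides along as a form field.
import requests

with open("example.png", "rb") as f:
    resp = requests.post(
        "http://localhost:5001/v1/files/upload",
        headers={"Authorization": "Bearer app-your-token"},
        files={"file": f},
        data={"user": "end-user-123"},
        timeout=30,
    )
print(resp.json())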
-from flask_restful import fields, marshal_with, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound

 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
-from extensions.ext_database import db
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import message_file_fields
 from libs.helper import TimestampField, uuid_value
-from models.model import EndUser, Message
+from models.model import App, EndUser
 from services.message_service import MessageService


-class MessageListApi(AppApiResource):
+class MessageListApi(Resource):
     feedback_fields = {
         'rating': fields.String
     }
@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
         'data': fields.List(fields.Nested(message_fields))
     }

+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(message_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
         parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
         parser.add_argument('first_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return MessageService.pagination_by_first_id(app_model, end_user,
                                                          args['conversation_id'], args['first_id'], args['limit'])
@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
             raise NotFound("First Message Not Exists.")


-class MessageFeedbackApi(AppApiResource):
-    def post(self, app_model, end_user, message_id):
+class MessageFeedbackApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+    def post(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)

         parser = reqparse.RequestParser()
         parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
-        parser.add_argument('user', type=str, location='json')
         args = parser.parse_args()

-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
         except services.errors.message.MessageNotExistsError:
@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
         return {'result': 'success'}


-class MessageSuggestedApi(AppApiResource):
-    def get(self, app_model, end_user, message_id):
+class MessageSuggestedApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+    def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         if app_model.mode != 'chat':
             raise NotChatAppError()

         try:
-            message = db.session.query(Message).filter(
-                Message.id == message_id,
-                Message.app_id == app_model.id,
-            ).first()
-
-            if end_user is None and message.from_end_user_id is not None:
-                user = db.session.query(EndUser) \
-                    .filter(
-                        EndUser.tenant_id == app_model.tenant_id,
-                        EndUser.id == message.from_end_user_id,
-                        EndUser.type == 'service_api'
-                    ).first()
-            else:
-                user = end_user
-
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model,
-                user=user,
+                user=end_user,
                 message_id=message_id,
                 check_enabled=False
             )
...
+from collections.abc import Callable
 from datetime import datetime
+from enum import Enum
 from functools import wraps
+from typing import Optional

 from flask import current_app, request
 from flask_login import user_logged_in
 from flask_restful import Resource
+from pydantic import BaseModel
 from werkzeug.exceptions import NotFound, Unauthorized

 from extensions.ext_database import db
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin
-from models.model import ApiToken, App
+from models.model import ApiToken, App, EndUser
 from services.feature_service import FeatureService


-def validate_app_token(view=None):
-    def decorator(view):
-        @wraps(view)
-        def decorated(*args, **kwargs):
+class WhereisUserArg(Enum):
+    """
+    Enum for whereis_user_arg.
+    """
+    QUERY = 'query'
+    JSON = 'json'
+    FORM = 'form'
+
+
+class FetchUserArg(BaseModel):
+    fetch_from: WhereisUserArg
+    required: bool = False
+
+
+def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
+    def decorator(view_func):
+        @wraps(view_func)
+        def decorated_view(*args, **kwargs):
             api_token = validate_and_get_api_token('app')

             app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -29,15 +47,34 @@ def validate_app_token(view=None):
             if not app_model.enable_api:
                 raise NotFound()

-            return view(app_model, None, *args, **kwargs)
-        return decorated
+            kwargs['app_model'] = app_model

-    if view:
-        return decorator(view)
+            if fetch_user_arg:
+                if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
+                    user_id = request.args.get('user')
+                elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
+                    user_id = request.get_json().get('user')
+                elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
+                    user_id = request.form.get('user')
+                else:
+                    # use default-user
+                    user_id = None

-    # if view is None, it means that the decorator is used without parentheses
-    # use the decorator as a function for method_decorators
-    return decorator
+                if not user_id and fetch_user_arg.required:
+                    raise ValueError("Arg user must be provided.")
+
+                if user_id:
+                    user_id = str(user_id)
+
+                kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
+
+            return view_func(*args, **kwargs)
+        return decorated_view
+
+    if view is None:
+        return decorator
+    else:
+        return decorator(view)
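
The trailing `if view is None` branch is the standard optional-argument decorator pattern: bare @validate_app_token receives the view function directly, while @validate_app_token(...) is first called with arguments and returns `decorator` to be applied next. A self-contained sketch of the pattern with generic names (not the Dify implementation):

# Generic optional-argument decorator, mirroring validate_app_token's shape.
from functools import wraps
from typing import Callable, Optional


def log_call(view: Optional[Callable] = None, *, label: str = "call"):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            print(f"{label}: {view_func.__name__}")
            return view_func(*args, **kwargs)
        return decorated_view

    # Bare usage (@log_call) gets the function; parameterized usage
    # (@log_call(label=...)) returns the decorator to apply next.
    if view is None:
        return decorator
    return decorator(view)


@log_call
def ping():
    return "pong"


@log_call(label="traced")
def echo(x):
    return x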
 def cloud_edition_billing_resource_check(resource: str,
@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
     return api_token


-class AppApiResource(Resource):
-    method_decorators = [validate_app_token]
+def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
+    """
+    Create or update session terminal based on user ID.
+    """
+    if not user_id:
+        user_id = 'DEFAULT-USER'
+
+    end_user = db.session.query(EndUser) \
+        .filter(
+            EndUser.tenant_id == app_model.tenant_id,
+            EndUser.app_id == app_model.id,
+            EndUser.session_id == user_id,
+            EndUser.type == 'service_api'
+        ).first()
+
+    if end_user is None:
+        end_user = EndUser(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            type='service_api',
+            is_anonymous=True if user_id == 'DEFAULT-USER' else False,
+            session_id=user_id
+        )
+        db.session.add(end_user)
+        db.session.commit()
+
+    return end_user


 class DatasetApiResource(Resource):
...
from typing import cast

from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel


class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
        """
        Get the remaining tokens available to the model after subtracting message tokens and the completion's max_tokens.

        :param model_config:
        :param messages:
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return 0

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens


class ExceededLLMTokensLimitError(Exception):
    pass
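
The arithmetic in get_message_rest_tokens is simply context window minus reserved completion tokens minus measured prompt tokens. A worked example with illustrative numbers:

# Worked example of the rest-tokens arithmetic (numbers are illustrative).
model_context_tokens = 4096   # model context window
max_tokens = 512              # reserved for the completion
prompt_tokens = 3200          # counted from the prompt messages
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
assert rest_tokens == 384     # summarization triggers only when this goes negative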
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = 0
for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
except Exception as e:
raise e
self.model_config.parameters['max_tokens'] = original_max_tokens
return True if result.message.tool_calls else False
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0
try:
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.message.content or "",
additional_kwargs={
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod
def get_system_message(cls):
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_config.provider == 'azure_openai':
model = model_config.model
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
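# --- Illustrative sketch (not from the source): the per-message accounting rule
# above, reduced to plain dicts. Assumes only that `tiktoken` is installed; the
# fixed overheads (4/3 tokens per message, 3-token reply priming) mirror the
# OpenAI cookbook rules referenced in the docstring.
import tiktoken

def approx_chat_tokens(messages: list[dict], model: str = "gpt-3.5-turbo") -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # same fallback as above
    tokens_per_message = 4 if model.startswith("gpt-3.5-turbo") else 3
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    return num_tokens + 3  # every reply is primed with <im_start>assistant

# approx_chat_tokens([{"role": "user", "content": "Hello!"}]) -> a small integer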
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
OutputParserException,
get_buffer_string,
)
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
    def should_use_agent(self, query: str):
        """
        Return whether the agent should be used.

        Probing with a ReACT round just to decide whether an agent is needed is
        itself costly, so we skip the check and always reason with the agent
        directly, which is cheaper overall.
        :param query: query
        :return: True
        """
        return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
try:
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the model returned an invalid answer; "
                                          "I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
self.moving_summary_index = len(intermediate_steps)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
if 'chat_history' in kwargs:
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
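    # Hypothetical usage of create_prompt (commented out; the Tool helper and
    # values are assumptions, not part of the source):
    #
    #   from langchain.tools import Tool
    #   search = Tool(name="search", func=lambda q: q, description="look things up")
    #   prompt = AutoSummarizingStructuredChatAgent.create_prompt(tools=[search])
    #   prompt.input_variables  # -> ['input', 'agent_scratchpad']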
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)
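# --- Illustrative sketch (not from the source): the moving-summary idea used by
# summarize_messages above, reduced to plain strings. `summarize` stands in for
# the LLMChain call; only the latest step stays verbatim in context.
def fold_steps(steps: list[str], summary: str, summarize) -> tuple[str, list[str]]:
    if len(steps) >= 2:
        # fold everything except the most recent step into the rolling summary
        summary = summarize(summary, steps[:-1])
        steps = steps[-1:]
    return summary, steps

# fold_steps(["obs1", "obs2", "obs3"], "",
#            lambda s, xs: (s + " " + " ".join(xs)).strip())
# -> ("obs1 obs2", ["obs3"])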
-import json
 import logging
 from typing import cast

@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
+from models.model import App, Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables

 logger = logging.getLogger(__name__)

@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

-        message_chain = self._init_message_chain(
-            message=message,
-            query=query
-        )
-
         # init model instance
         model_instance = ModelInstance(

@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
             'pool': db_variables.variables
         })

-    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
-        """
-        Init MessageChain
-        :param message: message
-        :param query: query
-        :return:
-        """
-        message_chain = MessageChain(
-            message_id=message.id,
-            type="AgentExecutor",
-            input=json.dumps({
-                "input": query
-            })
-        )
-
-        db.session.add(message_chain)
-        db.session.commit()
-
-        return message_chain
-
-    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
-        """
-        Save MessageChain
-        :param message_chain: message chain
-        :param output_text: output text
-        :return:
-        """
-        message_chain.output = json.dumps({
-            "output": output_text
-        })
-
-        db.session.commit()
-
     def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                          message: Message) -> LLMUsage:
         """
...
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
-from core.features.dataset_retrieval import DatasetRetrievalFeature
+from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.moderation.base import ModerationException
...
@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
                 'id': self._message.id,
                 'message_id': self._message.id,
                 'mode': self._conversation.mode,
-                'answer': event.llm_result.message.content,
+                'answer': self._task_state.llm_result.message.content,
                 'metadata': {},
                 'created_at': int(self._message.created_at.timestamp())
             }
...
@@ -3,12 +3,12 @@ import logging
 from typing import Optional, cast

 import numpy as np
-from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.entity.embedding import Embeddings
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
...
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
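# Illustrative helper (not in the source): pick a strategy from a model's
# feature list the way the agent runner below does; the 'tool-call' /
# 'multi-tool-call' feature values are assumptions.
def choose_strategy(features: list[str]) -> PlanningStrategy:
    if 'tool-call' in features or 'multi-tool-call' in features:
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT

# choose_strategy(['tool-call']) -> PlanningStrategy.FUNCTION_CALL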
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
        :param config: agent config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
        # check whether the model supports tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
            # check whether the agent should be used at all
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
        # pass if the dataset has no available documents or segments
        if (dataset.available_document_count == 0
                and dataset.available_segment_count == 0):
            return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool
\ No newline at end of file
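# --- Illustrative sketch (not from the source): the retrieval-config fallback
# implemented above, as a standalone function. The keys mirror
# default_retrieval_model; everything else is plain Python.
def resolve_retrieval(config: dict | None) -> tuple[int, float | None]:
    config = config or {'top_k': 2, 'score_threshold_enabled': False}
    top_k = config['top_k']
    score_threshold = config.get('score_threshold') if config.get('score_threshold_enabled') else None
    return top_k, score_threshold

# resolve_retrieval(None) -> (2, None)
# resolve_retrieval({'top_k': 4, 'score_threshold_enabled': True, 'score_threshold': 0.5}) -> (4, 0.5)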
@@ -59,7 +59,7 @@ class AnnotationReplyFeature:
             documents = vector.search_by_vector(
                 query=query,
-                k=1,
+                top_k=1,
                 score_threshold=score_threshold,
                 filter={
                     'group_id': [dataset.id]
...
@@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
         for message in messages:
             result.append(UserPromptMessage(content=message.query))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
-            for agent_thought in agent_thoughts:
-                tools = agent_thought.tool
-                if tools:
-                    tools = tools.split(';')
-                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
-                    tool_call_response: list[ToolPromptMessage] = []
-                    tool_inputs = json.loads(agent_thought.tool_input)
-                    for tool in tools:
-                        # generate a uuid for tool call
-                        tool_call_id = str(uuid.uuid4())
-                        tool_calls.append(AssistantPromptMessage.ToolCall(
-                            id=tool_call_id,
-                            type='function',
-                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                name=tool,
-                                arguments=json.dumps(tool_inputs.get(tool, {})),
-                            )
-                        ))
-                        tool_call_response.append(ToolPromptMessage(
-                            content=agent_thought.observation,
-                            name=tool,
-                            tool_call_id=tool_call_id,
-                        ))
-
-                    result.extend([
-                        AssistantPromptMessage(
-                            content=agent_thought.thought,
-                            tool_calls=tool_calls,
-                        ),
-                        *tool_call_response
-                    ])
+            if agent_thoughts:
+                for agent_thought in agent_thoughts:
+                    tools = agent_thought.tool
+                    if tools:
+                        tools = tools.split(';')
+                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                        tool_call_response: list[ToolPromptMessage] = []
+                        tool_inputs = json.loads(agent_thought.tool_input)
+                        for tool in tools:
+                            # generate a uuid for tool call
+                            tool_call_id = str(uuid.uuid4())
+                            tool_calls.append(AssistantPromptMessage.ToolCall(
+                                id=tool_call_id,
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool,
+                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                )
+                            ))
+                            tool_call_response.append(ToolPromptMessage(
+                                content=agent_thought.observation,
+                                name=tool,
+                                tool_call_id=tool_call_id,
+                            ))
+
+                        result.extend([
+                            AssistantPromptMessage(
+                                content=agent_thought.thought,
+                                tool_calls=tool_calls,
+                            ),
+                            *tool_call_response
+                        ])
+                    if not tools:
+                        result.append(AssistantPromptMessage(content=agent_thought.thought))
+            else:
+                if message.answer:
+                    result.append(AssistantPromptMessage(content=message.answer))

         return result
\ No newline at end of file
...
@@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 thought='',
                 action_str='',
                 observation='',
-                action=None
+                action=None,
             )

             # publish agent thought if it's first iteration
@@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     thought=message.content,
                     action_str='',
                     action=None,
-                    observation=None
+                    observation=None,
                 )
                 if message.tool_calls:
                     try:
@@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
             elif isinstance(message, ToolPromptMessage):
                 if current_scratchpad:
                     current_scratchpad.observation = message.content

         return agent_scratchpad

     def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
@@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     prompt_message.content = system_message
                     overridden = True
                     break
+
+        # convert tool prompt messages to user prompt messages
+        for idx, prompt_message in enumerate(prompt_messages):
+            if isinstance(prompt_message, ToolPromptMessage):
+                prompt_messages[idx] = UserPromptMessage(
+                    content=prompt_message.content
+                )

         if not overridden:
             prompt_messages.insert(0, SystemPromptMessage(
...
@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import Generation, LLMResult
 from langchain.schema.language_model import BaseLanguageModel

-from core.agent.agent.agent_llm_callback import AgentLLMCallback
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
-from core.third_party.langchain.llms.fake import FakeLLM


 class LLMChain(LCLLMChain):
...
@@ -12,9 +12,9 @@ from pydantic import root_validator

 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import lc_messages_to_prompt_messages
+from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM


 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
...
@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.tools import BaseTool

-from core.chain.llm_chain import LLMChain
 from core.entities.application_entities import ModelConfigEntity
+from core.features.dataset_retrieval.agent.llm_chain import LLMChain

 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
...
-import enum
 import logging
 from typing import Optional, Union

@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.tools import BaseTool
 from pydantic import BaseModel, Extra

-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
-from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
-from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
-from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
-from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import ModelConfigEntity
 from core.entities.message_entities import prompt_messages_to_lc_messages
+from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
+from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.helper import moderation
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool


-class PlanningStrategy(str, enum.Enum):
-    ROUTER = 'router'
-    REACT_ROUTER = 'react_router'
-    REACT = 'react'
-    FUNCTION_CALL = 'function_call'
-
-
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
         self.agent = self._init_agent()

     def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
-        if self.configuration.strategy == PlanningStrategy.REACT:
-            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                output_parser=StructuredChatOutputParser(),
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
-            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                model_config=self.configuration.model_config,
-                tools=self.configuration.tools,
-                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
-                if self.configuration.memory else None,  # used for read chat histories memory
-                summary_model_config=self.configuration.summary_model_config
-                if self.configuration.summary_model_config else None,
-                agent_llm_callback=self.configuration.agent_llm_callback,
-                verbose=True
-            )
-        elif self.configuration.strategy == PlanningStrategy.ROUTER:
+        if self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools
                                         if isinstance(t, DatasetRetrieverTool)
                                         or isinstance(t, DatasetMultiRetrieverTool)]
...
@@ -2,9 +2,10 @@ from typing import Optional, cast

 from langchain.tools import BaseTool

-from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.entities.agent_entities import PlanningStrategy
 from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
+from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
...
@@ -104,37 +104,17 @@ class HostingConfiguration:
         if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
             hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
+            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
             trial_quota = TrialHostingQuota(
                 quota_limit=hosted_quota_limit,
-                restrict_models=[
-                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
-                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
-                ]
+                restrict_models=trial_models
             )
             quotas.append(trial_quota)

         if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
+            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
             paid_quota = PaidHostingQuota(
-                restrict_models=[
-                    RestrictModel(model="gpt-4", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
-                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
-                ]
+                restrict_models=paid_models
             )
             quotas.append(paid_quota)
@@ -258,3 +238,11 @@ class HostingConfiguration:
         return HostedModerationConfig(
             enabled=False
         )
+
+    @staticmethod
+    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
+        models_str = app_config.get(env_var)
+        models_list = models_str.split(",") if models_str else []
+        return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM)
+                for model_name in models_list if model_name.strip()]
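# --- Illustrative sketch (not from the source): the parsing rule added above,
# without the Dify types. Empty and whitespace-only segments are dropped.
def parse_models(env_value: str | None) -> list[str]:
    return [name.strip() for name in (env_value or "").split(",") if name.strip()]

# parse_models("gpt-4, gpt-3.5-turbo,,") -> ['gpt-4', 'gpt-3.5-turbo']
# so HOSTED_OPENAI_TRIAL_MODELS="gpt-4,gpt-3.5-turbo" would yield two RestrictModel entries.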
@@ -365,8 +365,9 @@ class IndexingRunner:
                     notion_info={
                         "notion_workspace_id": data_source_info['notion_workspace_id'],
                         "notion_obj_id": data_source_info['notion_page_id'],
-                        "notion_page_type": data_source_info['notion_page_type'],
-                        "document": dataset_document
+                        "notion_page_type": data_source_info['type'],
+                        "document": dataset_document,
+                        "tenant_id": dataset_document.tenant_id
                     },
                     document_model=dataset_document.doc_form
                 )
@@ -664,6 +665,7 @@ class IndexingRunner:
             )

         # load index
         index_processor.load(dataset, chunk_documents)
+        db.session.add(dataset)

         document_ids = [document.metadata['doc_id'] for document in chunk_documents]
         db.session.query(DocumentSegment).filter(
...
@@ -20,7 +20,7 @@

 ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)

-  Shows the full list of supported providers. Besides each provider's name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For details of the rule design, see: [Schema](./schema.md)
+  Shows the full list of supported providers. Besides each provider's name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For details of the rule design, see: [Schema](./docs/zh_Hans/schema.md)

 - Display of the selectable model list
@@ -86,4 +86,4 @@ Model Runtime is organized into three layers:

 ![Alt text](docs/zh_Hans/images/index/image-2.png)

 ### [Concrete implementation of the interfaces 👈🏻](./docs/zh_Hans/interfaces.md)
-Here you can find the concrete implementation of each interface, together with the meaning of its parameters and return values.
\ No newline at end of file
+Here you can find the concrete implementation of each interface, together with the meaning of its parameters and return values.
@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
         'min': 1,
         'max': 2048,
         'precision': 0,
+    },
+    DefaultParameterName.RESPONSE_FORMAT: {
+        'label': {
+            'en_US': 'Response Format',
+            'zh_Hans': '回复格式',
+        },
+        'type': 'string',
+        'help': {
+            'en_US': 'Set a response format to ensure the output from the LLM is, as far as possible, a valid code block such as JSON or XML.',
+            'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
+        },
+        'required': False,
+        'options': ['JSON', 'XML'],
     }
 }
\ No newline at end of file
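# Hypothetical call shape (commented out; parameter values are assumptions):
# with the new rule, a provider that declares `use_template: response_format`
# accepts 'JSON' or 'XML' in model_parameters, e.g.
#
#   model_instance.invoke_llm(
#       prompt_messages=[UserPromptMessage(content='Describe X as {"answer": ...}')],
#       model_parameters={'temperature': 0.1, 'response_format': 'JSON'},
#   )
#
# LargeLanguageModel.invoke then routes through _code_block_mode_wrapper (below)
# instead of calling _invoke directly.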
@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
     PRESENCE_PENALTY = "presence_penalty"
     FREQUENCY_PENALTY = "frequency_penalty"
     MAX_TOKENS = "max_tokens"
+    RESPONSE_FORMAT = "response_format"

     @classmethod
     def value_of(cls, value: Any) -> 'DefaultParameterName':
...
@@ -262,23 +262,23 @@ class AIModel(ABC):
                 try:
                     default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                     default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
-                    if not parameter_rule.max:
+                    if not parameter_rule.max and 'max' in default_parameter_rule:
                         parameter_rule.max = default_parameter_rule['max']
-                    if not parameter_rule.min:
+                    if not parameter_rule.min and 'min' in default_parameter_rule:
                         parameter_rule.min = default_parameter_rule['min']
-                    if not parameter_rule.precision:
+                    if not parameter_rule.default and 'default' in default_parameter_rule:
                         parameter_rule.default = default_parameter_rule['default']
-                    if not parameter_rule.precision:
+                    if not parameter_rule.precision and 'precision' in default_parameter_rule:
                         parameter_rule.precision = default_parameter_rule['precision']
-                    if not parameter_rule.required:
+                    if not parameter_rule.required and 'required' in default_parameter_rule:
                         parameter_rule.required = default_parameter_rule['required']
-                    if not parameter_rule.help:
+                    if not parameter_rule.help and 'help' in default_parameter_rule:
                         parameter_rule.help = I18nObject(
                             en_US=default_parameter_rule['help']['en_US'],
                         )
-                    if not parameter_rule.help.en_US:
+                    if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
                         parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
-                    if not parameter_rule.help.zh_Hans:
+                    if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
                         parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
                 except ValueError:
                     pass
...
@@ -9,7 +9,13 @@ from typing import Optional, Union

 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.callbacks.logging_callback import LoggingCallback
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import (
     ModelPropertyKey,
     ModelType,
@@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel):
         )

         try:
-            result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+            if "response_format" in model_parameters:
+                result = self._code_block_mode_wrapper(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                    callbacks=callbacks
+                )
+            else:
+                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
         except Exception as e:
             self._trigger_invoke_error_callbacks(
                 model=model,
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):

         return result
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
You can find the structure of the {{block}} object in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
        stop = stop or []
        # stop at a code-fence boundary so only the block's contents are generated
        stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
first_chunk = next(response)
def new_generator():
yield first_chunk
yield from response
if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
)
else:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=new_generator()
)
return response
def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
input_generator: Generator[LLMResultChunk, None, None]
) -> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "normal"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
piece.delta.message.content = ""
yield piece
piece = content
else:
yield piece
continue
new_piece = ""
for char in piece:
if state == "normal":
if char == "`":
state = "in_backticks"
backtick_count = 1
else:
new_piece += char
elif state == "in_backticks":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_content"
backtick_count = 0
else:
new_piece += "`" * backtick_count + char
state = "normal"
backtick_count = 0
elif state == "skip_content":
if char.isspace():
state = "normal"
if new_piece:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
)
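    # Illustrative, not from the source: the same three-state machine as a pure
    # string function, handy for unit-testing the fence-stripping logic.
    @staticmethod
    def _strip_code_fences_demo(pieces: list[str]) -> str:
        out, state, backticks = "", "normal", 0
        for piece in pieces:
            for char in piece:
                if state == "normal":
                    if char == "`":
                        state, backticks = "in_backticks", 1
                    else:
                        out += char
                elif state == "in_backticks":
                    if char == "`":
                        backticks += 1
                        if backticks == 3:
                            # a full fence: start skipping until whitespace
                            state, backticks = "skip_content", 0
                    else:
                        # false alarm: restore the counted backticks
                        out += "`" * backticks + char
                        state, backticks = "normal", 0
                elif state == "skip_content":
                    if char.isspace():
                        state = "normal"
        return out

    # LargeLanguageModel._strip_code_fences_demo(['{"a": 1}', '``', '`json ']) -> '{"a": 1}'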
def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
input_generator: Generator[LLMResultChunk, None, None]) \
-> Generator[LLMResultChunk, None, None]:
"""
Code block mode stream processor, ensure the response is a code block with output markdown quote.
This version skips the language identifier that follows the opening triple backticks.
:param model: model name
:param prompt_messages: prompt messages
:param input_generator: input generator
:return: output generator
"""
state = "search_start"
backtick_count = 0
for piece in input_generator:
if piece.delta.message.content:
content = piece.delta.message.content
# Reset content to ensure we're only processing and yielding the relevant parts
piece.delta.message.content = ""
# Yield a piece with cleared content before processing it to maintain the generator structure
yield piece
piece = content
else:
# Yield pieces without content directly
yield piece
continue
if state == "done":
continue
new_piece = ""
for char in piece:
if state == "search_start":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "skip_language"
backtick_count = 0
else:
backtick_count = 0
elif state == "skip_language":
# Skip everything until the first newline, marking the end of the language identifier
if char == "\n":
state = "in_code_block"
elif state == "in_code_block":
if char == "`":
backtick_count += 1
if backtick_count == 3:
state = "done"
break
else:
if backtick_count > 0:
# If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count
backtick_count = 0
new_piece += char
elif state == "done":
break
if new_piece:
# Only yield content collected within the code block
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=new_piece,
tool_calls=[]
),
)
)
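    # Worked trace for the variant above (illustrative): the streamed text
    # '```json\n{"a": 1}\n```' may arrive as arbitrary pieces; 'search_start'
    # consumes the opening fence, 'skip_language' drops 'json\n', 'in_code_block'
    # yields '{"a": 1}' (a lone backtick inside the block is restored), and the
    # closing fence switches the machine to 'done', discarding everything after.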
     def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
                                  prompt_messages: list[PromptMessage], model_parameters: dict,
                                  tools: Optional[list[PromptMessageTool]] = None,
@@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel):
         :return: full response or stream response chunk generator result
         """
         raise NotImplementedError

     @abstractmethod
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
...
@@ -6,6 +6,7 @@
 - bedrock
 - togetherai
 - ollama
+- mistralai
 - replicate
 - huggingface_hub
 - zhipuai
...
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'
...
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'
...
@@ -26,6 +26,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '1.63'
   output: '5.51'
...
@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
 from anthropic.types import Completion, completion_create_params
 from httpx import Timeout

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+You can find the structure of the {{block}} object in the instructions; use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+<instructions>
+{{instructions}}
+</instructions>
+"""
+

 class AnthropicLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
self._transform_json_prompts(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(
content=f"```{response_format}\n"
))
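    # A minimal sketch of the transform's effect, assuming one system message and
    # response_format == 'JSON' (message contents below are illustrative only):
    #
    #   before: [SystemPromptMessage("Return the user's age."), UserPromptMessage("...")]
    #   after:  [SystemPromptMessage(ANTHROPIC_BLOCK_MODE_PROMPT with {{instructions}}
    #            -> "Return the user's age." and {{block}} -> "JSON"),
    #            UserPromptMessage("..."),
    #            AssistantPromptMessage("```JSON\n")]
    #
    # The trailing assistant message pre-opens the code fence so the model keeps
    # generating inside it, and the "```\n" stop sequence added above ends
    # generation at the closing fence.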
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
...
@@ -27,6 +27,8 @@ parameter_rules:
    default: 2048
    min: 1
    max: 2048
  - name: response_format
    use_template: response_format
pricing:
  input: '0.00'
  output: '0.00'
...
@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
You can find the structure of the {{block}} object in the instructions; if you are not sure about the structure,
use {"answer": "$your_answer"} as the default structure.
<instructions>
{{instructions}}
</instructions>
"""
class GoogleLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict,
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
...
- open-mistral-7b
- open-mixtral-8x7b
- mistral-small-latest
- mistral-medium-latest
- mistral-large-latest
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
        # mistral does not support user/stop arguments
stop = []
user = None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
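    # Mistral's chat API is served through an OpenAI-compatible surface, so this
    # adapter only pins the endpoint (https://api.mistral.ai/v1), forces chat
    # mode, and clears the `user`/`stop` arguments the upstream API does not
    # support; request/response handling is inherited from
    # OAIAPICompatLargeLanguageModel.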
model: mistral-large-latest
label:
zh_Hans: mistral-large-latest
en_US: mistral-large-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
    default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD
model: mistral-medium-latest
label:
zh_Hans: mistral-medium-latest
en_US: mistral-medium-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
    default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.0027'
output: '0.0081'
unit: '0.001'
currency: USD
model: mistral-small-latest
label:
zh_Hans: mistral-small-latest
en_US: mistral-small-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
    default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.002'
output: '0.006'
unit: '0.001'
currency: USD
model: open-mistral-7b
label:
zh_Hans: open-mistral-7b
en_US: open-mistral-7b
model_type: llm
features:
- agent-thought
model_properties:
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 2048
- name: safe_prompt
    default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.00025'
output: '0.00025'
unit: '0.001'
currency: USD
model: open-mixtral-8x7b
label:
zh_Hans: open-mixtral-8x7b
en_US: open-mixtral-8x7b
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
    default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.0007'
output: '0.0007'
unit: '0.001'
currency: USD
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class MistralAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='open-mistral-7b',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
provider: mistralai
label:
en_US: MistralAI
description:
en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest.
zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from MistralAI
zh_Hans: 从 MistralAI 获取 API Key
url:
en_US: https://console.mistral.ai/api-keys/
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
@@ -13,6 +13,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
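        # Moonshot appears to cap the `user` identifier at 32 characters, so the
        # value is truncated before delegating to the OpenAI-compatible base
        # class (the exact limit is an assumption inferred from this guard).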
user = user[:32] if user else None
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
...
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
      en_US: Specify the format that the model must output
required: false
options:
- text
- json_object
pricing:
  input: '0.0005'
  output: '0.0015'
...
@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
- name: response_format
use_template: response_format
pricing:
  input: '0.0015'
  output: '0.002'
...
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
      en_US: Specify the format that the model must output
required: false
options:
- text
- json_object
pricing:
  input: '0.001'
  output: '0.002'
...
@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 16385
- name: response_format
use_template: response_format
pricing:
  input: '0.003'
  output: '0.004'
...
@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 16385
- name: response_format
use_template: response_format
pricing:
  input: '0.003'
  output: '0.004'
...
@@ -21,6 +21,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
- name: response_format
use_template: response_format
pricing:
  input: '0.0015'
  output: '0.002'
...
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
      en_US: Specify the format that the model must output
required: false
options:
- text
- json_object
pricing:
  input: '0.001'
  output: '0.002'
...
@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
logger = logging.getLogger(__name__)
OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
You can find the structure of the {{block}} object in the instructions; if you are not sure about the structure,
use {"answer": "$your_answer"} as the default structure.
<instructions>
{{instructions}}
</instructions>
"""
class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
    """
@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                user=user
            )
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
# handle fine tune remote models
base_model = model
if model.startswith('ft:'):
base_model = model.split(':')[1]
# get model mode
model_mode = self.get_model_mode(base_model, credentials)
# transform response format
if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
stop = stop or []
if model_mode == LLMMode.CHAT:
# chat model
self._transform_chat_json_prompts(
model=base_model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
else:
self._transform_completion_json_prompts(
model=base_model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def _transform_completion_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# override the last user message
user_message = None
for i in range(len(prompt_messages) - 1, -1, -1):
if isinstance(prompt_messages[i], UserPromptMessage):
user_message = prompt_messages[i]
break
if user_message:
if prompt_messages[i].content[-11:] == 'Assistant: ':
# now we are in the chat app, remove the last assistant message
prompt_messages[i].content = prompt_messages[i].content[:-11]
prompt_messages[i] = UserPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", user_message.content)
.replace("{{block}}", response_format)
)
prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
else:
prompt_messages[i] = UserPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", user_message.content)
.replace("{{block}}", response_format)
)
prompt_messages[i].content += f"\n```{response_format}\n"
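    # In chat mode the JSON-mode instructions land in the system message and a
    # fresh assistant message pre-opens the fence; completion mode has no message
    # roles to lean on, so the last user message itself is rewritten and the
    # fence is opened at its tail. Both paths depend on the "```\n" / "\n```"
    # stop sequences added above to cut generation at the closing fence.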
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
...
@@ -13,6 +13,7 @@ from dashscope.common.error import (
)
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
-> LLMResult | Generator:
"""
Wrapper for code block mode
"""
        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
You can find the structure of the {{block}} object in the instructions; if you are not sure about the structure,
use {"answer": "$your_answer"} as the default structure.
<instructions>
{{instructions}}
</instructions>
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
mode = self.get_model_mode(model, credentials)
if mode == LLMMode.CHAT:
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
else:
prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=response
)
return response
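    # Streaming chunks still carry the opening/closing backticks, so the
    # generator is routed through the backtick-aware stream processor from the
    # base class, which (judging by its name) strips the fence markers before
    # yielding chunks downstream.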
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        """
        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop'] = stop
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        params = {
            'model': model,
            **model_parameters,
            **credentials_kwargs,
            **extra_model_kwargs,
        }
        mode = self.get_model_mode(model, credentials)
...
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty reduces repetition; 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty reduces repetition; 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty reduces repetition; 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
@@ -56,6 +56,8 @@ parameter_rules:
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty reduces repetition; 1.0 means no penalty.
  - name: response_format
    use_template: response_format
pricing:
  input: '0.02'
  output: '0.02'
...
@@ -57,6 +57,8 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty reduces repetition; 1.0 means no penalty.
    required: false
  - name: response_format
    use_template: response_format
pricing:
  input: '0.008'
  output: '0.008'
...
@@ -25,6 +25,8 @@ parameter_rules:
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索
...
@@ -25,6 +25,8 @@ parameter_rules:
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索
...
@@ -25,3 +25,5 @@ parameter_rules:
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
@@ -34,3 +34,5 @@ parameter_rules:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model from performing external searches.
    required: false
  - name: response_format
    use_template: response_format
from collections.abc import Generator
from typing import Optional, Union, cast

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
    RateLimitReachedError,
)
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
You can find the structure of the {{block}} object in the instructions; if you are not sure about the structure,
use {"answer": "$your_answer"} as the default structure.

<instructions>
{{instructions}}
</instructions>
You should also complete the text that starts with ``` but do not output ``` directly.
"""
class ErnieBotLargeLanguageModel(LargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel):
        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
response_format = model_parameters['response_format']
stop = stop or []
self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
model_parameters.pop('response_format')
if stream:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
)
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts to model prompts
"""
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ERNIE_BOT_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ERNIE_BOT_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += "\n```JSON\n{\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content="```JSON\n{\n"
))
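    # Unlike the Anthropic/OpenAI variants, this transform appears unable to end
    # on an assistant turn, so the fence plus an opening "{" is appended to the
    # last *user* message ("\n```JSON\n{\n"), nudging the model to continue the
    # JSON object from its first key.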
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        # tools is not supported yet
...
@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
You can find the structure of the JSON object in the instructions; if you are not sure about the structure,
use {"answer": "$your_answer"} as the default structure.
And you should always end the block with a "```" to indicate the end of the JSON object.
<instructions>
{{instructions}}
</instructions>
```JSON"""
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # invoke model
# stop = stop or []
# self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
# def _transform_json_prompts(self, model: str, credentials: dict,
# prompt_messages: list[PromptMessage], model_parameters: dict,
# tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
# stream: bool = True, user: str | None = None) \
# -> None:
# """
# Transform json prompts to model prompts
# """
# if "}\n\n" not in stop:
# stop.append("}\n\n")
# # check if there is a system message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# # override the system message
# prompt_messages[0] = SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
# )
# else:
# # insert the system message
# prompt_messages.insert(0, SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
# ))
# # check if the last message is a user message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# # add ```JSON\n to the last message
# prompt_messages[-1].content += "\n```JSON\n"
# else:
# # append a user message
# prompt_messages.append(UserPromptMessage(
# content="```JSON\n"
# ))
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        """
        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop'] = stop

        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        ]

        if stream:
            response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)

        response = client.chat.completions.create(**params, **extra_model_kwargs)
        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)

    def _handle_generate_response(self, model: str,
...
from typing import Any

from flask import current_app
@@ -14,7 +14,7 @@ class Keyword:
        self._keyword_processor = self._init_keyword()

    def _init_keyword(self) -> BaseKeyword:
        config = current_app.config
        keyword_type = config.get('KEYWORD_STORE')
        if not keyword_type:
@@ -25,7 +25,7 @@ class Keyword:
                dataset=self._dataset
            )
        else:
            raise ValueError(f"Keyword store {keyword_type} is not supported.")

    def create(self, texts: list[Document], **kwargs):
        self._keyword_processor.create(texts, **kwargs)
...
@@ -2,7 +2,6 @@ import threading
from typing import Optional

from flask import Flask, current_app
from flask_login import current_user
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
@@ -27,6 +26,11 @@ class RetrievalService:
    @classmethod
    def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
        all_documents = []
        threads = []
        # retrieval_model source with keyword
@@ -35,7 +39,8 @@ class RetrievalService:
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'top_k': top_k,
                'all_documents': all_documents
            })
            threads.append(keyword_thread)
            keyword_thread.start()
@@ -73,7 +78,7 @@ class RetrievalService:
            thread.join()
        if retrival_method == 'hybrid_search':
            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
            all_documents = data_post_processor.invoke(
                query=query,
                documents=all_documents,
@@ -96,7 +101,7 @@ class RetrievalService:
            documents = keyword.search(
                query,
                top_k=top_k
            )
            all_documents.extend(documents)
@@ -116,7 +121,7 @@ class RetrievalService:
            documents = vector.search_by_vector(
                query,
                search_type='similarity_score_threshold',
                top_k=top_k,
                score_threshold=score_threshold,
                filter={
                    'group_id': [dataset.id]
...
@@ -7,4 +7,4 @@ class Field(Enum):
    GROUP_KEY = "group_id"
    VECTOR = "vector"
    TEXT_KEY = "text"
    PRIMARY_KEY = "id"
@@ -124,12 +124,23 @@ class MilvusVector(BaseVector):
    def delete_by_ids(self, doc_ids: list[str]) -> None:
        result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
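        # Milvus deletes by primary key, while callers pass document IDs stored in
        # metadata, so the doc_ids are first resolved to primary keys via a
        # filtered query before the delete is issued.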
    def delete(self) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
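        # utility.drop_collection operates on a named connection, so a dedicated
        # aliased connection (uuid4 avoids clashing with the default alias) is
        # opened first and the drop is issued through it via `using=alias`.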
        from pymilvus import utility
        utility.drop_collection(self._collection_name, None, using=alias)

    def text_exists(self, id: str) -> bool:
...
import json
from typing import Any

from flask import current_app
@@ -22,7 +23,7 @@ class Vector:
        self._vector_processor = self._init_vector()

    def _init_vector(self) -> BaseVector:
        config = current_app.config
        vector_type = config.get('VECTOR_STORE')
        if self._dataset.index_struct_dict:
@@ -39,6 +40,11 @@ class Vector:
        else:
            dataset_id = self._dataset.id
            collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
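            # Persisting the index struct on first use lets later loads resolve
            # the same collection name instead of re-deriving it (the same
            # pattern is applied in the qdrant and milvus branches below).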
            return WeaviateVector(
                collection_name=collection_name,
                config=WeaviateConfig(
@@ -66,6 +72,13 @@ class Vector:
            dataset_id = self._dataset.id
            collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
if not self._dataset.index_struct_dict:
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
            return QdrantVector(
                collection_name=collection_name,
                group_id=self._dataset.id,
@@ -84,6 +97,11 @@ class Vector:
        else:
            dataset_id = self._dataset.id
            collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
            return MilvusVector(
                collection_name=collection_name,
                config=MilvusConfig(
...
@@ -127,7 +127,10 @@ class WeaviateVector(BaseVector):
        )

    def delete(self):
        # check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
self._client.schema.delete_class(self._collection_name)
    def text_exists(self, id: str) -> bool:
        collection_name = self._collection_name
@@ -147,10 +150,11 @@ class WeaviateVector(BaseVector):
        return True

    def delete_by_ids(self, ids: list[str]) -> None:
        for uuid in ids:
            self._client.data_object.delete(
                class_name=self._collection_name,
                uuid=uuid,
)
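    # data_object.delete expects a single uuid per call, so the ids are deleted
    # one at a time instead of being passed as a list.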
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Look up similar documents by embedding vector in Weaviate."""
...
@@ -12,6 +12,7 @@ class NotionInfo(BaseModel):
    notion_obj_id: str
    notion_page_type: str
    document: Document = None
tenant_id: str
    class Config:
        arbitrary_types_allowed = True
...
@@ -132,7 +132,8 @@ class ExtractProcessor:
                notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
                notion_obj_id=extract_setting.notion_info.notion_obj_id,
                notion_page_type=extract_setting.notion_info.notion_page_type,
                document_model=extract_setting.notion_info.document,
tenant_id=extract_setting.notion_info.tenant_id,
            )
            return extractor.extract()
        else:
...
"""Abstract interface for document loader implementations.""" """Abstract interface for document loader implementations."""
from typing import Optional from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document

class HtmlExtractor(BaseExtractor):
    """
    Load html files.
    Args:
@@ -15,57 +16,19 @@ class HtmlExtractor(BaseExtractor):
    """

    def __init__(
        self,
        file_path: str
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
    ):
        """Initialize with file path."""
        self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column
self.csv_args = csv_args or {}
    def extract(self) -> list[Document]:
        return [Document(page_content=self._load_as_text())]
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
return docs
    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
        return text
\ No newline at end of file
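# A hedged usage sketch of the rewritten extractor (the file path is
# illustrative): BeautifulSoup's get_text() flattens every text node into one
# string, so the result is the page's text with markup removed.
#
#   extractor = HtmlExtractor('/tmp/page.html')
#   docs = extractor.extract()   # -> [Document(page_content='...page text...')]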
@@ -4,7 +4,6 @@ from typing import Any, Optional
import requests

from flask import current_app
from flask_login import current_user
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@@ -30,8 +29,10 @@ class NotionExtractor(BaseExtractor):
            notion_workspace_id: str,
            notion_obj_id: str,
            notion_page_type: str,
tenant_id: str,
            document_model: Optional[DocumentModel] = None,
            notion_access_token: Optional[str] = None,
    ):
        self._notion_access_token = None
        self._document_model = document_model
@@ -41,7 +42,7 @@ class NotionExtractor(BaseExtractor):
        if notion_access_token:
            self._notion_access_token = notion_access_token
        else:
            self._notion_access_token = self._get_access_token(tenant_id,
                                                               self._notion_workspace_id)
            if not self._notion_access_token:
                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
...
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket
class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'
model_api_configs = {
'spark': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-v2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-v3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-v3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
}
}
api_version = model_api_configs[model_name]['version']
self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
urlparse(self.api_base).path,
self.api_base,
api_key,
api_secret
)
self.queue = queue.Queue()
self.blocking_message = ''
def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
# generate timestamp by RFC1123
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"
# encrypt using hmac-sha256
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
}
# generate url
url = api_base + '?' + urlencode(v)
return url
def run(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None, streaming: bool = False):
websocket.enableTrace(False)
ws = websocket.WebSocketApp(
self.ws_url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
on_open=self.on_open
)
ws.messages = messages
ws.user_id = user_id
ws.model_kwargs = model_kwargs
ws.streaming = streaming
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error):
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close()
def on_close(self, ws, close_status_code, close_reason):
self.queue.put({'done': True})
def on_open(self, ws):
self.blocking_message = ''
data = json.dumps(self.gen_params(
messages=ws.messages,
user_id=ws.user_id,
model_kwargs=ws.model_kwargs
))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
self.queue.put({
'status_code': 400,
'error': f"Code: {code}, Error: {data['header']['message']}"
})
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
if ws.streaming:
self.queue.put({'data': content})
else:
self.blocking_message += content
if status == 2:
if not ws.streaming:
self.queue.put({'data': self.blocking_message})
ws.close()
def gen_params(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None) -> dict:
data = {
"header": {
"app_id": self.app_id,
"uid": user_id
},
"parameter": {
"chat": {
"domain": self.chat_domain
}
},
"payload": {
"message": {
"text": messages
}
}
}
if model_kwargs:
data['parameter']['chat'].update(model_kwargs)
return data
def subscribe(self):
while True:
content = self.queue.get()
if 'error' in content:
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
if 'data' not in content:
break
yield content
class SparkError(Exception):
pass
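# Editor's note: a minimal usage sketch for SparkLLMClient, assuming the
# placeholder credentials below are replaced with real iFLYTEK Spark values.
# `run()` blocks inside `run_forever()`, so it is started on a worker thread
# while `subscribe()` drains the result queue on the caller's side.
if __name__ == '__main__':
    import threading

    client = SparkLLMClient(
        model_name='spark-v3.5',
        app_id='your-app-id',          # hypothetical placeholder
        api_key='your-api-key',        # hypothetical placeholder
        api_secret='your-api-secret',  # hypothetical placeholder
    )
    threading.Thread(target=client.run, kwargs={
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'user_id': 'user-1',
        'streaming': True,
    }).start()
    for chunk in client.subscribe():
        print(chunk['data'], end='', flush=True)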
from datetime import datetime
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class DatetimeToolInput(BaseModel):
    type: str = Field(..., description="Type of time to query; must be: datetime.")
class DatetimeTool(BaseTool):
"""Tool for querying current datetime."""
name: str = "current_datetime"
args_schema: type[BaseModel] = DatetimeToolInput
description: str = "A tool when you want to get the current date, time, week, month or year, " \
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
def _run(self, type: str) -> str:
# get current time
current_time = datetime.utcnow()
return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
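# Editor's note: a quick illustrative call, assuming a LangChain-compatible
# environment; BaseTool.run validates the input against DatetimeToolInput first.
#   DatetimeTool().run({'type': 'datetime'})
#   # -> e.g. "2024-02-26 08:00:00 UTC+0000 Monday"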
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.tool_name == self.get_provider_name().value
)
if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)
return query.first()
def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
if obfuscated:
return self._obfuscated_token(token)
return token
def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
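    # Editor's note (illustrative): `_obfuscated_token` keeps the first six and
    # last two characters and masks the rest, e.g.
    #   'sk-abcdefghij1234' -> 'sk-abc*********34'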
class ToolValidateFailedError(Exception):
description = "Tool Provider Validate failed"
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.errors import ToolValidateFailedError
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper
from models.tool import ToolProviderName
class SerpAPIToolProvider(BaseToolProvider):
def get_provider_name(self) -> ToolProviderName:
"""
Returns the name of the provider.
:return:
"""
return ToolProviderName.SERPAPI
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for SerpAPI as a dictionary.
:param obfuscated: obfuscate credentials if True
:return:
"""
tool_provider = self.get_provider(must_enabled=True)
if not tool_provider:
return None
credentials = tool_provider.credentials
if not credentials:
return None
if credentials.get('api_key'):
credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated)
return credentials
def credentials_to_func_kwargs(self) -> Optional[dict]:
"""
Returns the credentials function kwargs as a dictionary.
:return:
"""
credentials = self.get_credentials()
if not credentials:
return None
return {
'serpapi_api_key': credentials.get('api_key')
}
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:return:
"""
if 'api_key' not in credentials or not credentials.get('api_key'):
raise ToolValidateFailedError("SerpAPI api_key is required.")
api_key = credentials.get('api_key')
try:
OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test')
except Exception as e:
raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e))
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
credentials['api_key'] = self.encrypt_token(credentials.get('api_key'))
return credentials
from typing import Optional
from core.tool.provider.base import BaseToolProvider
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
class ToolProviderService:
def __init__(self, tenant_id: str, provider_name: str):
self.provider = self._init_provider(tenant_id, provider_name)
def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider:
if provider_name == 'serpapi':
return SerpAPIToolProvider(tenant_id)
else:
raise Exception('tool provider {} not found'.format(provider_name))
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Returns the credentials for Tool as a dictionary.
:param obfuscated:
:return:
"""
return self.provider.get_credentials(obfuscated)
def credentials_validate(self, credentials: dict):
"""
Validates the given credentials.
:param credentials:
:raises: ValidateFailedError
"""
return self.provider.credentials_validate(credentials)
def encrypt_credentials(self, credentials: dict):
"""
Encrypts the given credentials.
:param credentials:
:return:
"""
return self.provider.encrypt_credentials(credentials)
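# Editor's note: a minimal sketch of the service layer, with placeholder values
# (the tenant id and API key below are hypothetical, not real credentials):
#
#   service = ToolProviderService(tenant_id='tenant-uuid', provider_name='serpapi')
#   service.credentials_validate({'api_key': 'your-serpapi-key'})  # raises ToolValidateFailedError on failure
#   encrypted = service.encrypt_credentials({'api_key': 'your-serpapi-key'})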
from langchain import SerpAPIWrapper
from pydantic import BaseModel, Field
class OptimizedSerpAPIInput(BaseModel):
query: str = Field(..., description="search query.")
class OptimizedSerpAPIWrapper(SerpAPIWrapper):
@staticmethod
def _process_response(res: dict, num_results: int = 5) -> str:
"""Process response from SerpAPI."""
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
res["answer_box"] = res["answer_box"][0]
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
toret = res["sports_results"]["game_spotlight"]
elif (
"shopping_results" in res.keys()
and "title" in res["shopping_results"][0].keys()
):
toret = res["shopping_results"][:3]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
toret = res["knowledge_graph"]["description"]
elif 'organic_results' in res.keys() and len(res['organic_results']) > 0:
toret = ""
for result in res["organic_results"][:num_results]:
if "link" in result:
toret += "----------------\nlink: " + result["link"] + "\n"
if "snippet" in result:
toret += "snippet: " + result["snippet"] + "\n"
else:
toret = "No good search result found"
return "search result:\n" + toret
import hashlib
import json
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from typing import Any, Optional
import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:
{text}
"""
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
summary: bool = Field(
default=False,
description="When the user's question requires extracting the summarizing content of the webpage, "
"set it to true."
)
cursor: int = Field(
        default=0,
        description="Start reading from this character. "
                    "Use when the first response was truncated "
                    "and you want to continue reading the page. "
                    "The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
    page_contents: Optional[str] = None
    url: Optional[str] = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
model_config: ModelConfigEntity
model_parameters: dict[str, Any]
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
except Exception as e:
            return f'Failed to read this website, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
except Exception as e:
                return f'Failed to read this website, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.PROMPT,
parameters=self.model_parameters
)
refine_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.REFINE_PROMPT,
parameters=self.model_parameters
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length]
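# Editor's note: `page_result` is a plain slice over the page text, e.g.
#   page_result("abcdefgh", cursor=2, max_length=3)  # -> "cde"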
def get_url(url: str) -> str:
"""Fetch URL and return the contents as a string."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
if head_response.status_code != 200:
return "URL returned status code {}.".format(head_response.status_code)
# check content-type
    main_content_type = head_response.headers.get('Content-Type', '').split(';')[0].strip()
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
a = extract_using_readabilipy(response.text)
if not a['plain_text'] or not a['plain_text'].strip():
return get_url_from_newspaper3k(url)
res = FULL_TEMPLATE.format(
title=a['title'],
authors=a['byline'],
publish_date=a['date'],
top_image="",
text=a['plain_text'] if a['plain_text'] else "",
)
return res
def get_url_from_newspaper3k(url: str) -> str:
a = Article(url)
a.download()
a.parse()
res = FULL_TEMPLATE.format(
title=a.title,
authors=a.authors,
publish_date=a.publish_date,
top_image=a.top_image,
text=a.text,
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
# Read output of call to Readability.parse() from JSON file and return as Python dictionary
with open(article_json_path, encoding="utf-8") as json_file:
input_json = json.loads(json_file.read())
# Deleting files after processing
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
"title": None,
"byline": None,
"date": None,
"content": None,
"plain_content": None,
"plain_text": None
}
# Populate article fields from readability fields where present
if input_json:
if "title" in input_json and input_json["title"]:
article_json["title"] = input_json["title"]
if "byline" in input_json and input_json["byline"]:
article_json["byline"] = input_json["byline"]
if "date" in input_json and input_json["date"]:
article_json["date"] = input_json["date"]
if "content" in input_json and input_json["content"]:
article_json["content"] = input_json["content"]
article_json["plain_content"] = plain_content(article_json["content"], False, False)
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if "textContent" in input_json and input_json["textContent"]:
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
return article_json
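# Editor's note: extract_using_readabilipy shells out to Readability.js through
# the readabilipy package's bundled ExtractArticle.js, so a node runtime must be
# on PATH. An illustrative call:
#   article = extract_using_readabilipy('<html><body><p>Hi there.</p></body></html>')
#   article['plain_text']  # normalised text blocks, or None if nothing was extracted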
def find_module_path(module_name):
for package_path in site.getsitepackages():
potential_path = os.path.join(package_path, module_name)
if os.path.exists(potential_path):
return potential_path
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
# From https://stackoverflow.com/a/37996581, couldn't find a built-in
original_path = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_path)
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, 'html.parser')
# Select all lists
list_elements = soup.find_all(['ul', 'ol'])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
text_blocks = [s.parent for s in soup.find_all(string=True)]
text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
# Drop empty paragraphs
text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
return text_blocks
def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements, and normalise it
plain_text = normalise_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
plain_text = None
if "data-node-index" in element.attrs:
plain = {"node_index": element["data-node-index"], "text": plain_text}
else:
plain = {"text": plain_text}
return plain
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, 'html.parser')
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
# Add node index attributes to nodes
elements = [add_node_indexes(element) for element in elements]
# Replace article contents with plain elements
soup.contents = elements
return str(soup)
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes)
for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
return elements
def plain_element(element, content_digests, node_indexes):
# For lists, we make each item plain text
if is_leaf(element):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalise the extracted text string to a canonical representation
plain_text = normalise_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
if is_non_printing(element):
# The simplified HTML may have come from Readability.js so might
# have non-printing text (e.g. Comment or CData). In this case, we
# keep the structure, but ensure that the string is empty.
element = type(element)("")
else:
plain_text = element.string
plain_text = normalise_text(plain_text)
element = type(element)(plain_text)
else:
# If not a leaf node or leaf type call recursively on child nodes, replacing
element.contents = plain_elements(element.contents, content_digests, node_indexes)
return element
def add_node_indexes(element, node_index="0"):
# Can't add attributes to string types
if is_text(element):
return element
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate(
[c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(
stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
def normalise_text(text):
"""Normalise unicode and whitespace."""
# Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
text = strip_control_characters(text)
text = normalise_unicode(text)
text = normalise_whitespace(text)
return text
def strip_control_characters(text):
"""Strip out unicode control characters which might break the parsing."""
# Unicode control characters
# [Cc]: Other, Control [includes new lines]
# [Cf]: Other, Format
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs'])
retained_chars = ['\t', '\n', '\r', '\f']
# Remove non-printing control characters
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
def normalise_unicode(text):
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalise_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace
text = text.strip()
return text
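# Editor's note: a worked example of the normalisation pipeline (illustrative):
#   normalise_text('ﬁsh\u00a0and\tchips\n')  # -> 'fish and chips'
# NFKC folds the 'ﬁ' ligature and the non-breaking space to plain characters;
# the whitespace pass then collapses runs of whitespace and strips the ends.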
def is_leaf(element):
return (element.name in ['p', 'li'])
def is_text(element):
return isinstance(element, NavigableString)
def is_non_printing(element):
return any(isinstance(element, _e) for _e in [Comment, CData])
def add_content_digest(element):
if not is_text(element):
element["data-content-digest"] = content_digest(element)
return element
def content_digest(element):
if is_text(element):
# Hash
trimmed_string = element.string.strip()
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
if num_contents == 0:
# No hash when no child elements exist
digest = ""
elif num_contents == 1:
# If single child, use digest of child
digest = content_digest(contents[0])
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(
filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode('utf-8'))
digest = digest.hexdigest()
return digest
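# Editor's note: an illustrative digest, assuming bs4 is available; for a leaf
# element with a single text child the digest is the SHA-256 of the trimmed text:
#   content_digest(BeautifulSoup('<p>hello</p>', 'html.parser').p)
#   # == hashlib.sha256(b'hello').hexdigest()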
...@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController): ...@@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController):
en_US='The api key', en_US='The api key',
zh_Hans='api key的值' zh_Hans='api key的值'
) )
),
'api_key_header_prefix': ToolProviderCredentials(
name='api_key_header_prefix',
required=False,
default='basic',
type=ToolProviderCredentials.CredentialsType.SELECT,
help=I18nObject(
en_US='The prefix of the api key header',
zh_Hans='api key header 的前缀'
),
options=[
ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
]
) )
} }
elif auth_type == ApiProviderAuthType.NONE: elif auth_type == ApiProviderAuthType.NONE:
......
...@@ -62,6 +62,17 @@ class ApiTool(Tool): ...@@ -62,6 +62,17 @@ class ApiTool(Tool):
if 'api_key_value' not in credentials: if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value') raise ToolProviderCredentialValidationError('Missing api_key_value')
elif not isinstance(credentials['api_key_value'], str):
raise ToolProviderCredentialValidationError('api_key_value must be a string')
if 'api_key_header_prefix' in credentials:
api_key_header_prefix = credentials['api_key_header_prefix']
if api_key_header_prefix == 'basic':
credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
elif api_key_header_prefix == 'bearer':
credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == 'custom':
pass
headers[api_key_header] = credentials['api_key_value'] headers[api_key_header] = credentials['api_key_value']
...@@ -200,7 +211,7 @@ class ApiTool(Tool): ...@@ -200,7 +211,7 @@ class ApiTool(Tool):
# replace path parameters # replace path parameters
for name, value in path_params.items(): for name, value in path_params.items():
url = url.replace(f'{{{name}}}', value) url = url.replace(f'{{{name}}}', f'{value}')
# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
if 'Content-Type' in headers: if 'Content-Type' in headers:
......
...@@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool): ...@@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool):
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
top_k=self.top_k top_k=self.top_k
......
...@@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool):
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
top_k=self.top_k top_k=self.top_k
......
...@@ -4,7 +4,7 @@ from langchain.tools import BaseTool ...@@ -4,7 +4,7 @@ from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
...@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool): ...@@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool):
@staticmethod @staticmethod
def get_dataset_tools(tenant_id: str, def get_dataset_tools(tenant_id: str,
dataset_ids: list[str], dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity, retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool, return_resource: bool,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler hit_callback: DatasetIndexToolCallbackHandler
) -> list['DatasetRetrieverTool']: ) -> list['DatasetRetrieverTool']:
""" """
get dataset tool get dataset tool
""" """
...@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool): ...@@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool):
) )
# restore retrieve strategy # restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools # convert langchain tools to Tools
tools = [] tools = []
for langchain_tool in langchain_tools: for langchain_tool in langchain_tools:
...@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool): ...@@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool):
llm=langchain_tool.description), llm=langchain_tool.description),
runtime=DatasetRetrieverTool.Runtime() runtime=DatasetRetrieverTool.Runtime()
) )
tools.append(tool) tools.append(tool)
return tools return tools
...@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool): ...@@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool):
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(self) -> list[ToolParameter]:
return [ return [
ToolParameter(name='query', ToolParameter(name='query',
label=I18nObject(en_US='', zh_Hans=''), label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''), human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParameter.ToolParameterType.STRING, type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.', llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True, required=True,
default=''), default=''),
] ]
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
...@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool): ...@@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool):
query = tool_parameters.get('query', None) query = tool_parameters.get('query', None)
if not query: if not query:
return self.create_text_message(text='please input query') return self.create_text_message(text='please input query')
# invoke dataset retriever tool # invoke dataset retriever tool
result = self.langchain_tool._run(query=query) result = self.langchain_tool._run(query=query)
...@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool): ...@@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool):
""" """
validate the credentials for dataset retriever tool validate the credentials for dataset retriever tool
""" """
pass pass
\ No newline at end of file
import re
import uuid
from json import loads as json_loads from json import loads as json_loads
from requests import get from requests import get
...@@ -46,7 +48,7 @@ class ApiBasedToolSchemaParser: ...@@ -46,7 +48,7 @@ class ApiBasedToolSchemaParser:
parameters = [] parameters = []
if 'parameters' in interface['operation']: if 'parameters' in interface['operation']:
for parameter in interface['operation']['parameters']: for parameter in interface['operation']['parameters']:
parameters.append(ToolParameter( tool_parameter = ToolParameter(
name=parameter['name'], name=parameter['name'],
label=I18nObject( label=I18nObject(
en_US=parameter['name'], en_US=parameter['name'],
...@@ -61,7 +63,14 @@ class ApiBasedToolSchemaParser: ...@@ -61,7 +63,14 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get('description'), llm_description=parameter.get('description'),
default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None, default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None,
)) )
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle # create tool bundle
# check if there is a request body # check if there is a request body
if 'requestBody' in interface['operation']: if 'requestBody' in interface['operation']:
...@@ -80,13 +89,14 @@ class ApiBasedToolSchemaParser: ...@@ -80,13 +89,14 @@ class ApiBasedToolSchemaParser:
root = root[ref] root = root[ref]
# overwrite the content # overwrite the content
interface['operation']['requestBody']['content'][content_type]['schema'] = root interface['operation']['requestBody']['content'][content_type]['schema'] = root
# parse body parameters # parse body parameters
if 'schema' in interface['operation']['requestBody']['content'][content_type]: if 'schema' in interface['operation']['requestBody']['content'][content_type]:
body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
required = body_schema['required'] if 'required' in body_schema else [] required = body_schema['required'] if 'required' in body_schema else []
properties = body_schema['properties'] if 'properties' in body_schema else {} properties = body_schema['properties'] if 'properties' in body_schema else {}
for name, property in properties.items(): for name, property in properties.items():
parameters.append(ToolParameter( tool = ToolParameter(
name=name, name=name,
label=I18nObject( label=I18nObject(
en_US=name, en_US=name,
...@@ -101,7 +111,14 @@ class ApiBasedToolSchemaParser: ...@@ -101,7 +111,14 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM, form=ToolParameter.ToolParameterForm.LLM,
llm_description=property['description'] if 'description' in property else '', llm_description=property['description'] if 'description' in property else '',
default=property['default'] if 'default' in property else None, default=property['default'] if 'default' in property else None,
)) )
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated # check if parameters is duplicated
parameters_count = {} parameters_count = {}
...@@ -119,7 +136,11 @@ class ApiBasedToolSchemaParser: ...@@ -119,7 +136,11 @@ class ApiBasedToolSchemaParser:
path = interface['path'] path = interface['path']
if interface['path'].startswith('/'): if interface['path'].startswith('/'):
path = interface['path'][1:] path = interface['path'][1:]
path = path.replace('/', '_') # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r'[^a-zA-Z0-9_-]', '', path)
if not path:
path = str(uuid.uuid4())
interface['operation']['operationId'] = f'{path}_{interface["method"]}' interface['operation']['operationId'] = f'{path}_{interface["method"]}'
bundles.append(ApiBasedToolBundle( bundles.append(ApiBasedToolBundle(
...@@ -134,7 +155,23 @@ class ApiBasedToolSchemaParser: ...@@ -134,7 +155,23 @@ class ApiBasedToolSchemaParser:
)) ))
return bundles return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
parameter = parameter or {}
typ = None
if 'type' in parameter:
typ = parameter['type']
elif 'schema' in parameter and 'type' in parameter['schema']:
typ = parameter['schema']['type']
if typ == 'integer' or typ == 'number':
return ToolParameter.ToolParameterType.NUMBER
elif typ == 'boolean':
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == 'string':
return ToolParameter.ToolParameterType.STRING
@staticmethod @staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]:
""" """
......
...@@ -7,23 +7,14 @@ import subprocess ...@@ -7,23 +7,14 @@ import subprocess
import tempfile import tempfile
import unicodedata import unicodedata
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any
import requests import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article from newspaper import Article
from pydantic import BaseModel, Field
from regex import regex from regex import regex
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """ FULL_TEMPLATE = """
TITLE: {title} TITLE: {title}
...@@ -36,106 +27,6 @@ TEXT: ...@@ -36,106 +27,6 @@ TEXT:
""" """
class WebReaderToolInput(BaseModel):
url: str = Field(..., description="URL of the website to read")
summary: bool = Field(
default=False,
description="When the user's question requires extracting the summarizing content of the webpage, "
"set it to true."
)
cursor: int = Field(
default=0,
description="Start reading from this character."
"Use when the first response was truncated"
"and you want to continue reading the page."
"The value cannot exceed 24000.",
)
class WebReaderTool(BaseTool):
"""Reader tool for getting website title and contents. Gives more control than SimpleReaderTool."""
name: str = "web_reader"
args_schema: type[BaseModel] = WebReaderToolInput
description: str = "use this to read a website. " \
"If you can answer the question based on the information provided, " \
"there is no need to use."
page_contents: str = None
url: str = None
max_chunk_length: int = 4000
summary_chunk_tokens: int = 4000
summary_chunk_overlap: int = 0
summary_separators: list[str] = ["\n\n", "。", ".", " ", ""]
continue_reading: bool = True
model_config: ModelConfigEntity
model_parameters: dict[str, Any]
def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str:
try:
if not self.page_contents or self.url != url:
page_contents = get_url(url)
self.page_contents = page_contents
self.url = url
else:
page_contents = self.page_contents
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
if summary:
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.summary_chunk_tokens,
chunk_overlap=self.summary_chunk_overlap,
separators=self.summary_separators
)
texts = character_splitter.split_text(page_contents)
docs = [Document(page_content=t) for t in texts]
if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'):
return "No content found."
# only use first 5 docs
if len(docs) > 5:
docs = docs[:5]
chain = self.get_summary_chain()
try:
page_contents = chain.run(docs)
except Exception as e:
return f'Read this website failed, caused by: {str(e)}.'
else:
page_contents = page_result(page_contents, cursor, self.max_chunk_length)
if self.continue_reading and len(page_contents) >= self.max_chunk_length:
page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \
f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \
f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING."
return page_contents
async def _arun(self, url: str) -> str:
raise NotImplementedError
def get_summary_chain(self) -> RefineDocumentsChain:
initial_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.PROMPT,
parameters=self.model_parameters
)
refine_chain = LLMChain(
model_config=self.model_config,
prompt=refine_prompts.REFINE_PROMPT,
parameters=self.model_parameters
)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name="text",
initial_response_name="existing_answer",
callbacks=self.callbacks
)
def page_result(text: str, cursor: int, max_length: int) -> str: def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length] return text[cursor: cursor + max_length]
......
import re import re
import uuid import uuid
from core.agent.agent_executor import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
......
...@@ -133,8 +133,9 @@ class HitTestingService: ...@@ -133,8 +133,9 @@ class HitTestingService:
if embedding_length <= 1: if embedding_length <= 1:
return [{'x': 0, 'y': 0}] return [{'x': 0, 'y': 0}]
concatenate_data = np.array(embeddings).reshape(embedding_length, -1) noise = np.random.normal(0, 1e-4, np.array(embeddings).shape)
# concatenate_data = np.concatenate(embeddings) concatenate_data = np.array(embeddings) + noise
concatenate_data = concatenate_data.reshape(embedding_length, -1)
perplexity = embedding_length / 2 + 1 perplexity = embedding_length / 2 + 1
if perplexity >= embedding_length: if perplexity >= embedding_length:
......
from flask import current_app
from flask_login import current_user from flask_login import current_user
from extensions.ext_database import db from extensions.ext_database import db
...@@ -31,7 +33,15 @@ class WorkspaceService: ...@@ -31,7 +33,15 @@ class WorkspaceService:
can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): if can_replace_logo and TenantService.has_roles(tenant,
tenant_info['custom_config'] = tenant.custom_config_dict [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
base_url = current_app.config.get('FILES_URL')
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
tenant_info['custom_config'] = {
'remove_webapp_brand': remove_webapp_brand,
'replace_webapp_logo': replace_webapp_logo,
}
return tenant_info return tenant_info
...@@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, ...@@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
indexing_technique=indexing_technique, indexing_technique=indexing_technique,
index_struct=index_struct, index_struct=index_struct,
collection_binding_id=collection_binding_id, collection_binding_id=collection_binding_id,
doc_form=doc_form
) )
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
......
...@@ -58,7 +58,8 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -58,7 +58,8 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
notion_workspace_id=workspace_id, notion_workspace_id=workspace_id,
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type, notion_page_type=page_type,
notion_access_token=data_source_binding.access_token notion_access_token=data_source_binding.access_token,
tenant_id=document.tenant_id
) )
last_edited_time = loader.get_notion_last_edited_time() last_edited_time = loader.get_notion_last_edited_time()
......
...@@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, ...@@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel
def test_predefined_models(): def test_predefined_models():
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
model_schemas = model.predefined_models() model_schemas = model.predefined_models()
assert len(model_schemas) >= 1 assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity) assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model(): def test_validate_credentials_for_chat_model():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError): with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials( model.validate_credentials(
...@@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model(): ...@@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model():
def test_invoke_model_ernie_bot(): def test_invoke_model_ernie_bot():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot', model='ernie-bot',
...@@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot(): ...@@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot():
def test_invoke_model_ernie_bot_turbo(): def test_invoke_model_ernie_bot_turbo():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot-turbo', model='ernie-bot-turbo',
...@@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo(): ...@@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo():
def test_invoke_model_ernie_8k(): def test_invoke_model_ernie_8k():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot-8k', model='ernie-bot-8k',
...@@ -123,7 +123,7 @@ def test_invoke_model_ernie_8k(): ...@@ -123,7 +123,7 @@ def test_invoke_model_ernie_8k():
def test_invoke_model_ernie_bot_4(): def test_invoke_model_ernie_bot_4():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot-4', model='ernie-bot-4',
...@@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4(): ...@@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4():
def test_invoke_stream_model(): def test_invoke_stream_model():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot', model='ernie-bot',
...@@ -182,7 +182,7 @@ def test_invoke_stream_model(): ...@@ -182,7 +182,7 @@ def test_invoke_stream_model():
def test_invoke_model_with_system(): def test_invoke_model_with_system():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot', model='ernie-bot',
...@@ -212,7 +212,7 @@ def test_invoke_model_with_system(): ...@@ -212,7 +212,7 @@ def test_invoke_model_with_system():
def test_invoke_with_search(): def test_invoke_with_search():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.invoke( response = model.invoke(
model='ernie-bot', model='ernie-bot',
...@@ -250,7 +250,7 @@ def test_invoke_with_search(): ...@@ -250,7 +250,7 @@ def test_invoke_with_search():
def test_get_num_tokens(): def test_get_num_tokens():
sleep(3) sleep(3)
model = ErnieBotLarguageModel() model = ErnieBotLargeLanguageModel()
response = model.get_num_tokens( response = model.get_num_tokens(
model='ernie-bot', model='ernie-bot',
......
...@@ -2,7 +2,7 @@ version: '3.1' ...@@ -2,7 +2,7 @@ version: '3.1'
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.5.6 image: langgenius/dify-api:0.5.7
restart: always restart: always
environment: environment:
# Startup mode, 'api' starts the API server. # Startup mode, 'api' starts the API server.
...@@ -135,7 +135,7 @@ services: ...@@ -135,7 +135,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.5.6 image: langgenius/dify-api:0.5.7
restart: always restart: always
environment: environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue. # Startup mode, 'worker' starts the Celery worker for processing the queue.
...@@ -206,7 +206,7 @@ services: ...@@ -206,7 +206,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.5.6 image: langgenius/dify-web:0.5.7
restart: always restart: always
environment: environment:
EDITION: SELF_HOSTED EDITION: SELF_HOSTED
......
...@@ -12,6 +12,7 @@ import { useAppContext } from '@/context/app-context' ...@@ -12,6 +12,7 @@ import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { CheckModal } from '@/hooks/use-pay' import { CheckModal } from '@/hooks/use-pay'
import TabSliderNew from '@/app/components/base/tab-slider-new' import TabSliderNew from '@/app/components/base/tab-slider-new'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import { SearchLg } from '@/app/components/base/icons/src/vender/line/general' import { SearchLg } from '@/app/components/base/icons/src/vender/line/general'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
...@@ -35,7 +36,9 @@ const getKey = ( ...@@ -35,7 +36,9 @@ const getKey = (
const Apps = () => { const Apps = () => {
const { t } = useTranslation() const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext() const { isCurrentWorkspaceManager } = useAppContext()
const [activeTab, setActiveTab] = useState('all') const [activeTab, setActiveTab] = useTabSearchParams({
defaultTab: 'all',
})
const [keywords, setKeywords] = useState('') const [keywords, setKeywords] = useState('')
const [searchKeywords, setSearchKeywords] = useState('') const [searchKeywords, setSearchKeywords] = useState('')
......
'use client' 'use client'
// Libraries // Libraries
import { useRef, useState } from 'react' import { useRef } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import useSWR from 'swr' import useSWR from 'swr'
...@@ -15,6 +15,9 @@ import TabSliderNew from '@/app/components/base/tab-slider-new' ...@@ -15,6 +15,9 @@ import TabSliderNew from '@/app/components/base/tab-slider-new'
// Services // Services
import { fetchDatasetApiBaseUrl } from '@/service/datasets' import { fetchDatasetApiBaseUrl } from '@/service/datasets'
// Hooks
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
const Container = () => { const Container = () => {
const { t } = useTranslation() const { t } = useTranslation()
...@@ -23,7 +26,9 @@ const Container = () => { ...@@ -23,7 +26,9 @@ const Container = () => {
{ value: 'api', text: t('dataset.datasetsApi') }, { value: 'api', text: t('dataset.datasetsApi') },
] ]
const [activeTab, setActiveTab] = useState('dataset') const [activeTab, setActiveTab] = useTabSearchParams({
defaultTab: 'dataset',
})
const containerRef = useRef<HTMLDivElement>(null) const containerRef = useRef<HTMLDivElement>(null)
const { data } = useSWR(activeTab === 'dataset' ? null : '/datasets/api-base-info', fetchDatasetApiBaseUrl) const { data } = useSWR(activeTab === 'dataset' ? null : '/datasets/api-base-info', fetchDatasetApiBaseUrl)
......
...@@ -42,6 +42,7 @@ const HeaderOptions: FC<Props> = ({ ...@@ -42,6 +42,7 @@ const HeaderOptions: FC<Props> = ({
const { locale } = useContext(I18n) const { locale } = useContext(I18n)
const { CSVDownloader, Type } = useCSVDownloader() const { CSVDownloader, Type } = useCSVDownloader()
const [list, setList] = useState<AnnotationItemBasic[]>([]) const [list, setList] = useState<AnnotationItemBasic[]>([])
const annotationUnavailable = list.length === 0
const listTransformer = (list: AnnotationItemBasic[]) => list.map( const listTransformer = (list: AnnotationItemBasic[]) => list.map(
(item: AnnotationItemBasic) => { (item: AnnotationItemBasic) => {
...@@ -116,11 +117,11 @@ const HeaderOptions: FC<Props> = ({ ...@@ -116,11 +117,11 @@ const HeaderOptions: FC<Props> = ({
...list.map(item => [item.question, item.answer]), ...list.map(item => [item.question, item.answer]),
]} ]}
> >
<button className={s.actionItem}> <button disabled={annotationUnavailable} className={s.actionItem}>
<span className={s.actionName}>CSV</span> <span className={s.actionName}>CSV</span>
</button> </button>
</CSVDownloader> </CSVDownloader>
<button className={cn(s.actionItem, '!border-0')} onClick={JSONLOutput}> <button disabled={annotationUnavailable} className={cn(s.actionItem, '!border-0')} onClick={JSONLOutput}>
<span className={s.actionName}>JSONL</span> <span className={s.actionName}>JSONL</span>
</button> </button>
</Menu.Items> </Menu.Items>
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
} }
.actionItem { .actionItem {
@apply h-9 py-2 px-3 mx-1 flex items-center space-x-2 hover:bg-gray-100 rounded-lg cursor-pointer; @apply h-9 py-2 px-3 mx-1 flex items-center space-x-2 hover:bg-gray-100 rounded-lg cursor-pointer disabled:opacity-50;
width: calc(100% - 0.5rem); width: calc(100% - 0.5rem);
} }
...@@ -35,4 +35,4 @@ ...@@ -35,4 +35,4 @@
left: 4px; left: 4px;
transform: translateX(-100%); transform: translateX(-100%);
box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03);
} }
\ No newline at end of file
...@@ -18,7 +18,7 @@ import { getNewVar } from '@/utils/var' ...@@ -18,7 +18,7 @@ import { getNewVar } from '@/utils/var'
import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight' import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight'
import { Plus, Trash03 } from '@/app/components/base/icons/src/vender/line/general' import { Plus, Trash03 } from '@/app/components/base/icons/src/vender/line/general'
const MAX_QUESTION_NUM = 3 const MAX_QUESTION_NUM = 5
export type IOpeningStatementProps = { export type IOpeningStatementProps = {
value: string value: string
......
import type { FC } from 'react' import type { FC } from 'react'
import classNames from 'classnames'
type LogoSiteProps = { type LogoSiteProps = {
className?: string className?: string
} }
const LogoSite: FC<LogoSiteProps> = ({ const LogoSite: FC<LogoSiteProps> = ({
className, className,
}) => { }) => {
return ( return (
<img <img
src='/logo/logo-site.png' src='/logo/logo-site.png'
className={`block w-auto h-10 ${className}`} className={classNames('block w-auto h-10', className)}
alt='logo' alt='logo'
/> />
) )
......
'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React, { useEffect } from 'react' import React from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
import PlanComp from '../plan' import PlanComp from '../plan'
import { ReceiptList } from '../../base/icons/src/vender/line/financeAndECommerce' import { ReceiptList } from '../../base/icons/src/vender/line/financeAndECommerce'
import { LinkExternal01 } from '../../base/icons/src/vender/line/general' import { LinkExternal01 } from '../../base/icons/src/vender/line/general'
...@@ -12,17 +13,11 @@ import { useProviderContext } from '@/context/provider-context' ...@@ -12,17 +13,11 @@ import { useProviderContext } from '@/context/provider-context'
const Billing: FC = () => { const Billing: FC = () => {
const { t } = useTranslation() const { t } = useTranslation()
const { isCurrentWorkspaceManager } = useAppContext() const { isCurrentWorkspaceManager } = useAppContext()
const [billingUrl, setBillingUrl] = React.useState('')
const { enableBilling } = useProviderContext() const { enableBilling } = useProviderContext()
const { data: billingUrl } = useSWR(
useEffect(() => { (!enableBilling || !isCurrentWorkspaceManager) ? null : ['/billing/invoices'],
if (!enableBilling || !isCurrentWorkspaceManager) () => fetchBillingUrl().then(data => data.url),
return )
(async () => {
const { url } = await fetchBillingUrl()
setBillingUrl(url)
})()
}, [isCurrentWorkspaceManager])
return ( return (
<div> <div>
...@@ -39,4 +34,5 @@ const Billing: FC = () => { ...@@ -39,4 +34,5 @@ const Billing: FC = () => {
</div> </div>
) )
} }
export default React.memo(Billing) export default React.memo(Billing)
...@@ -16,8 +16,6 @@ import { ...@@ -16,8 +16,6 @@ import {
updateCurrentWorkspace, updateCurrentWorkspace,
} from '@/service/common' } from '@/service/common'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { API_PREFIX } from '@/config'
import { getPurifyHref } from '@/utils'
const ALLOW_FILE_EXTENSIONS = ['svg', 'png'] const ALLOW_FILE_EXTENSIONS = ['svg', 'png']
...@@ -123,7 +121,7 @@ const CustomWebAppBrand = () => { ...@@ -123,7 +121,7 @@ const CustomWebAppBrand = () => {
POWERED BY POWERED BY
{ {
webappLogo webappLogo
? <img key={webappLogo} src={`${getPurifyHref(API_PREFIX.slice(0, -12))}/files/workspaces/${currentWorkspace.id}/webapp-logo`} alt='logo' className='ml-2 block w-auto h-5' /> ? <img key={webappLogo} src={webappLogo} alt='logo' className='ml-2 block w-auto h-5' />
: <LogoSite className='ml-2 !h-5' /> : <LogoSite className='ml-2 !h-5' />
} }
</div> </div>
......
...@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran ...@@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran
</Col> </Col>
<Col sticky> <Col sticky>
### Request Example ### Request Example
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
......
...@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success - `result` (string) 固定返回 success
</Col> </Col>
<Col sticky> <Col sticky>
<CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/completion-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
......
...@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to ...@@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to
</Col> </Col>
<Col sticky> <Col sticky>
### Request Example ### Request Example
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{"user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{"user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
...@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to ...@@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to
- (string) url of icon - (string) url of icon
</Col> </Col>
<Col> <Col>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}> <CodeGroup title="Request" tag="GET" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \ curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}' -H 'Authorization: Bearer {api_key}'
``` ```
......
...@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- `result` (string) 固定返回 success - `result` (string) 固定返回 success
</Col> </Col>
<Col sticky> <Col sticky>
<CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}> <CodeGroup title="Request" tag="POST" label="/chat-messages/:task_id/stop" targetCode={`curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \\\n-H 'Authorization: Bearer {api_key}' \\\n-H 'Content-Type: application/json' \\\n--data-raw '{ "user": "abc-123"}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \ curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \
-H 'Authorization: Bearer {api_key}' \ -H 'Authorization: Bearer {api_key}' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
--data-raw '{ --data-raw '{
...@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ...@@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
- (string) 图标URL - (string) 图标URL
</Col> </Col>
<Col> <Col>
<CodeGroup title="Request" tag="POST" label="/meta" targetCode={`curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}> <CodeGroup title="Request" tag="GET" label="/meta" targetCode={`curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \\\n-H 'Authorization: Bearer {api_key}'`}>
```bash {{ title: 'cURL' }} ```bash {{ title: 'cURL' }}
curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \ curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \
-H 'Authorization: Bearer {api_key}' -H 'Authorization: Bearer {api_key}'
``` ```
......
'use client' 'use client'
import React, { useEffect } from 'react' import React from 'react'
import cn from 'classnames' import cn from 'classnames'
import { useRouter } from 'next/navigation' import { useRouter } from 'next/navigation'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
import useSWR from 'swr'
import Toast from '../../base/toast' import Toast from '../../base/toast'
import s from './style.module.css' import s from './style.module.css'
import ExploreContext from '@/context/explore-context' import ExploreContext from '@/context/explore-context'
import type { App, AppCategory } from '@/models/explore' import type { App } from '@/models/explore'
import Category from '@/app/components/explore/category' import Category from '@/app/components/explore/category'
import AppCard from '@/app/components/explore/app-card' import AppCard from '@/app/components/explore/app-card'
import { fetchAppDetail, fetchAppList } from '@/service/explore' import { fetchAppDetail, fetchAppList } from '@/service/explore'
import { createApp } from '@/service/apps' import { createApp } from '@/service/apps'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import CreateAppModal from '@/app/components/explore/create-app-modal' import CreateAppModal from '@/app/components/explore/create-app-modal'
import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal'
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
...@@ -36,32 +38,43 @@ const Apps = ({ ...@@ -36,32 +38,43 @@ const Apps = ({
const { isCurrentWorkspaceManager } = useAppContext() const { isCurrentWorkspaceManager } = useAppContext()
const router = useRouter() const router = useRouter()
const { hasEditPermission } = useContext(ExploreContext) const { hasEditPermission } = useContext(ExploreContext)
const [currCategory, setCurrCategory] = React.useState<AppCategory | ''>('') const allCategoriesEn = t('explore.apps.allCategories')
const [allList, setAllList] = React.useState<App[]>([]) const [currCategory, setCurrCategory] = useTabSearchParams({
const [isLoaded, setIsLoaded] = React.useState(false) defaultTab: allCategoriesEn,
})
const {
data: { categories, allList },
} = useSWR(
['/explore/apps'],
() =>
fetchAppList().then(({ categories, recommended_apps }) => ({
categories,
allList: recommended_apps.sort((a, b) => a.position - b.position),
})),
{
fallbackData: {
categories: [],
allList: [],
},
},
)
const currList = (() => { const currList = (() => {
if (currCategory === '') if (currCategory === allCategoriesEn)
return allList return allList
return allList.filter(item => item.category === currCategory) return allList.filter(item => item.category === currCategory)
})() })()
const [categories, setCategories] = React.useState<AppCategory[]>([])
useEffect(() => {
(async () => {
const { categories, recommended_apps }: any = await fetchAppList()
const sortedRecommendedApps = [...recommended_apps]
sortedRecommendedApps.sort((a, b) => a.position - b.position) // position from small to big
setCategories(categories)
setAllList(sortedRecommendedApps)
setIsLoaded(true)
})()
}, [])
const [currApp, setCurrApp] = React.useState<App | null>(null) const [currApp, setCurrApp] = React.useState<App | null>(null)
const [isShowCreateModal, setIsShowCreateModal] = React.useState(false) const [isShowCreateModal, setIsShowCreateModal] = React.useState(false)
const onCreate: CreateAppModalProps['onConfirm'] = async ({ name, icon, icon_background, description }) => { const onCreate: CreateAppModalProps['onConfirm'] = async ({
const { app_model_config: model_config } = await fetchAppDetail(currApp?.app.id as string) name,
icon,
icon_background,
}) => {
const { app_model_config: model_config } = await fetchAppDetail(
currApp?.app.id as string,
)
try { try {
const app = await createApp({ const app = await createApp({
...@@ -78,17 +91,20 @@ const Apps = ({ ...@@ -78,17 +91,20 @@ const Apps = ({
message: t('app.newApp.appCreated'), message: t('app.newApp.appCreated'),
}) })
localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1')
router.push(`/app/${app.id}/${isCurrentWorkspaceManager ? 'configuration' : 'overview'}`) router.push(
`/app/${app.id}/${isCurrentWorkspaceManager ? 'configuration' : 'overview'
}`,
)
} }
catch (e) { catch (e) {
Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') })
} }
} }
if (!isLoaded) { if (!categories) {
return ( return (
<div className='flex h-full items-center'> <div className="flex h-full items-center">
<Loading type='area' /> <Loading type="area" />
</div> </div>
) )
} }
......
...@@ -138,16 +138,12 @@ export default function AccountSetting({ ...@@ -138,16 +138,12 @@ export default function AccountSetting({
] ]
const scrollRef = useRef<HTMLDivElement>(null) const scrollRef = useRef<HTMLDivElement>(null)
const [scrolled, setScrolled] = useState(false) const [scrolled, setScrolled] = useState(false)
const scrollHandle = (e: Event) => {
if ((e.target as HTMLDivElement).scrollTop > 0)
setScrolled(true)
else
setScrolled(false)
}
useEffect(() => { useEffect(() => {
const targetElement = scrollRef.current const targetElement = scrollRef.current
const scrollHandle = (e: Event) => {
const userScrolled = (e.target as HTMLDivElement).scrollTop > 0
setScrolled(userScrolled)
}
targetElement?.addEventListener('scroll', scrollHandle) targetElement?.addEventListener('scroll', scrollHandle)
return () => { return () => {
targetElement?.removeEventListener('scroll', scrollHandle) targetElement?.removeEventListener('scroll', scrollHandle)
......
...@@ -54,7 +54,7 @@ const Header = () => { ...@@ -54,7 +54,7 @@ const Header = () => {
</div>} </div>}
{!isMobile && <> {!isMobile && <>
<Link href="/apps" className='flex items-center mr-4'> <Link href="/apps" className='flex items-center mr-4'>
<LogoSite /> <LogoSite className='object-contain' />
</Link> </Link>
<GithubStar /> <GithubStar />
</>} </>}
......
...@@ -3,11 +3,13 @@ import type { FC } from 'react' ...@@ -3,11 +3,13 @@ import type { FC } from 'react'
import React from 'react' import React from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import Tooltip from '../../base/tooltip'
import { HelpCircle } from '../../base/icons/src/vender/line/general'
import type { Credential } from '@/app/components/tools/types' import type { Credential } from '@/app/components/tools/types'
import Drawer from '@/app/components/base/drawer-plus' import Drawer from '@/app/components/base/drawer-plus'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Radio from '@/app/components/base/radio/ui' import Radio from '@/app/components/base/radio/ui'
import { AuthType } from '@/app/components/tools/types' import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'
type Props = { type Props = {
credential: Credential credential: Credential
...@@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium text-gray-900' ...@@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium text-gray-900'
type ItemProps = { type ItemProps = {
text: string text: string
value: AuthType value: AuthType | AuthHeaderPrefix
isChecked: boolean isChecked: boolean
onClick: (value: AuthType) => void onClick: (value: AuthType | AuthHeaderPrefix) => void
} }
const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => { const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => {
...@@ -31,7 +33,6 @@ const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => { ...@@ -31,7 +33,6 @@ const SelectItem: FC<ItemProps> = ({ text, value, isChecked, onClick }) => {
> >
<Radio isChecked={isChecked} /> <Radio isChecked={isChecked} />
<div className='text-sm font-normal text-gray-900'>{text}</div> <div className='text-sm font-normal text-gray-900'>{text}</div>
</div> </div>
) )
} }
...@@ -43,6 +44,7 @@ const ConfigCredential: FC<Props> = ({ ...@@ -43,6 +44,7 @@ const ConfigCredential: FC<Props> = ({
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const [tempCredential, setTempCredential] = React.useState<Credential>(credential) const [tempCredential, setTempCredential] = React.useState<Credential>(credential)
return ( return (
<Drawer <Drawer
isShow isShow
...@@ -62,20 +64,59 @@ const ConfigCredential: FC<Props> = ({ ...@@ -62,20 +64,59 @@ const ConfigCredential: FC<Props> = ({
text={t('tools.createTool.authMethod.types.none')} text={t('tools.createTool.authMethod.types.none')}
value={AuthType.none} value={AuthType.none}
isChecked={tempCredential.auth_type === AuthType.none} isChecked={tempCredential.auth_type === AuthType.none}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value })} onClick={value => setTempCredential({ ...tempCredential, auth_type: value as AuthType })}
/> />
<SelectItem <SelectItem
text={t('tools.createTool.authMethod.types.api_key')} text={t('tools.createTool.authMethod.types.api_key')}
value={AuthType.apiKey} value={AuthType.apiKey}
isChecked={tempCredential.auth_type === AuthType.apiKey} isChecked={tempCredential.auth_type === AuthType.apiKey}
onClick={value => setTempCredential({ ...tempCredential, auth_type: value })} onClick={value => setTempCredential({
...tempCredential,
auth_type: value as AuthType,
api_key_header: tempCredential.api_key_header || 'Authorization',
api_key_value: tempCredential.api_key_value || '',
api_key_header_prefix: tempCredential.api_key_header_prefix || AuthHeaderPrefix.custom,
})}
/> />
</div> </div>
</div> </div>
{tempCredential.auth_type === AuthType.apiKey && ( {tempCredential.auth_type === AuthType.apiKey && (
<> <>
<div className={keyClassNames}>{t('tools.createTool.authHeaderPrefix.title')}</div>
<div className='flex space-x-3'>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.basic')}
value={AuthHeaderPrefix.basic}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.basic}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.bearer')}
value={AuthHeaderPrefix.bearer}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.bearer}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
<SelectItem
text={t('tools.createTool.authHeaderPrefix.types.custom')}
value={AuthHeaderPrefix.custom}
isChecked={tempCredential.api_key_header_prefix === AuthHeaderPrefix.custom}
onClick={value => setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })}
/>
</div>
<div> <div>
<div className={keyClassNames}>{t('tools.createTool.authMethod.key')}</div> <div className='flex items-center h-8 text-[13px] font-medium text-gray-900'>
{t('tools.createTool.authMethod.key')}
<Tooltip
selector='model-page-system-reasoning-model-tip'
htmlContent={
<div className='w-[261px] text-gray-500'>
{t('tools.createTool.authMethod.keyTooltip')}
</div>
}
>
<HelpCircle className='ml-0.5 w-[14px] h-[14px] text-gray-400'/>
</Tooltip>
</div>
<input <input
value={tempCredential.api_key_header} value={tempCredential.api_key_header}
onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })} onChange={e => setTempCredential({ ...tempCredential, api_key_header: e.target.value })}
...@@ -83,7 +124,6 @@ const ConfigCredential: FC<Props> = ({ ...@@ -83,7 +124,6 @@ const ConfigCredential: FC<Props> = ({
placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!} placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!}
/> />
</div> </div>
<div> <div>
<div className={keyClassNames}>{t('tools.createTool.authMethod.value')}</div> <div className={keyClassNames}>{t('tools.createTool.authMethod.value')}</div>
<input <input
......
...@@ -8,7 +8,7 @@ import { clone } from 'lodash-es' ...@@ -8,7 +8,7 @@ import { clone } from 'lodash-es'
import cn from 'classnames' import cn from 'classnames'
import { LinkExternal02, Settings01 } from '../../base/icons/src/vender/line/general' import { LinkExternal02, Settings01 } from '../../base/icons/src/vender/line/general'
import type { Credential, CustomCollectionBackend, CustomParamSchema, Emoji } from '../types' import type { Credential, CustomCollectionBackend, CustomParamSchema, Emoji } from '../types'
import { AuthType } from '../types' import { AuthHeaderPrefix, AuthType } from '../types'
import GetSchema from './get-schema' import GetSchema from './get-schema'
import ConfigCredentials from './config-credentials' import ConfigCredentials from './config-credentials'
import TestApi from './test-api' import TestApi from './test-api'
...@@ -37,6 +37,7 @@ const EditCustomCollectionModal: FC<Props> = ({ ...@@ -37,6 +37,7 @@ const EditCustomCollectionModal: FC<Props> = ({
const { t } = useTranslation() const { t } = useTranslation()
const isAdd = !payload const isAdd = !payload
const isEdit = !!payload const isEdit = !!payload
const [editFirst, setEditFirst] = useState(!isAdd) const [editFirst, setEditFirst] = useState(!isAdd)
const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || []) const [paramsSchemas, setParamsSchemas] = useState<CustomParamSchema[]>(payload?.tools || [])
const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd const [customCollection, setCustomCollection, getCustomCollection] = useGetState<CustomCollectionBackend>(isAdd
...@@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC<Props> = ({ ...@@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC<Props> = ({
provider: '', provider: '',
credentials: { credentials: {
auth_type: AuthType.none, auth_type: AuthType.none,
api_key_header: 'Authorization',
api_key_header_prefix: AuthHeaderPrefix.basic,
}, },
icon: { icon: {
content: '🕵️', content: '🕵️',
......
...@@ -16,6 +16,7 @@ import EditCustomToolModal from './edit-custom-collection-modal' ...@@ -16,6 +16,7 @@ import EditCustomToolModal from './edit-custom-collection-modal'
import NoCustomTool from './info/no-custom-tool' import NoCustomTool from './info/no-custom-tool'
import NoSearchRes from './info/no-search-res' import NoSearchRes from './info/no-search-res'
import NoCustomToolPlaceholder from './no-custom-tool-placeholder' import NoCustomToolPlaceholder from './no-custom-tool-placeholder'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'
import TabSlider from '@/app/components/base/tab-slider' import TabSlider from '@/app/components/base/tab-slider'
import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools'
import type { AgentTool } from '@/types/app' import type { AgentTool } from '@/types/app'
...@@ -68,7 +69,9 @@ const Tools: FC<Props> = ({ ...@@ -68,7 +69,9 @@ const Tools: FC<Props> = ({
})() })()
const [query, setQuery] = useState('') const [query, setQuery] = useState('')
const [collectionType, setCollectionType] = useState<CollectionType>(collectionTypeOptions[0].value) const [collectionType, setCollectionType] = useTabSearchParams({
defaultTab: collectionTypeOptions[0].value,
})
const showCollectionList = (() => { const showCollectionList = (() => {
let typeFilteredList: Collection[] = [] let typeFilteredList: Collection[] = []
......
...@@ -3,7 +3,7 @@ import type { FC } from 'react' ...@@ -3,7 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useState } from 'react' import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import cn from 'classnames' import cn from 'classnames'
import { CollectionType, LOC } from '../types' import { AuthHeaderPrefix, AuthType, CollectionType, LOC } from '../types'
import type { Collection, CustomCollectionBackend, Tool } from '../types' import type { Collection, CustomCollectionBackend, Tool } from '../types'
import Loading from '../../base/loading' import Loading from '../../base/loading'
import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows'
...@@ -53,6 +53,10 @@ const ToolList: FC<Props> = ({ ...@@ -53,6 +53,10 @@ const ToolList: FC<Props> = ({
(async () => { (async () => {
if (collection.type === CollectionType.custom) { if (collection.type === CollectionType.custom) {
const res = await fetchCustomCollection(collection.name) const res = await fetchCustomCollection(collection.name)
if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) {
if (res.credentials.api_key_value)
res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom
}
setCustomCollection({ setCustomCollection({
...res, ...res,
provider: collection.name, provider: collection.name,
......
...@@ -9,10 +9,17 @@ export enum AuthType { ...@@ -9,10 +9,17 @@ export enum AuthType {
apiKey = 'api_key', apiKey = 'api_key',
} }
export enum AuthHeaderPrefix {
basic = 'basic',
bearer = 'bearer',
custom = 'custom',
}
export type Credential = { export type Credential = {
'auth_type': AuthType 'auth_type': AuthType
'api_key_header'?: string 'api_key_header'?: string
'api_key_value'?: string 'api_key_value'?: string
'api_key_header_prefix'?: AuthHeaderPrefix
} }
export enum CollectionType { export enum CollectionType {
......
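For context, a minimal sketch of a `Credential` object using the new `api_key_header_prefix` field (field values are illustrative placeholders, not taken from this commit):

```typescript
import type { Credential } from '@/app/components/tools/types'
import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'

// Illustrative only: an API-key credential where the key is presumably sent
// as "Bearer <api_key_value>" in the configured header.
const credential: Credential = {
  auth_type: AuthType.apiKey,
  api_key_header: 'Authorization',
  api_key_header_prefix: AuthHeaderPrefix.bearer,
  api_key_value: 'sk-placeholder-key', // placeholder secret
}
```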
...@@ -122,7 +122,7 @@ export const useCheckNotion = () => { ...@@ -122,7 +122,7 @@ export const useCheckNotion = () => {
const notionCode = searchParams.get('code') const notionCode = searchParams.get('code')
const notionError = searchParams.get('error') const notionError = searchParams.get('error')
const { data } = useSWR( const { data } = useSWR(
canBinding (canBinding && notionCode)
? `/oauth/data-source/binding/notion?code=${notionCode}` ? `/oauth/data-source/binding/notion?code=${notionCode}`
: null, : null,
fetchDataSourceNotionBinding, fetchDataSourceNotionBinding,
......
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
type UseTabSearchParamsOptions = {
defaultTab: string
routingBehavior?: 'push' | 'replace'
searchParamName?: string
}
/**
* Custom hook to manage tab state via URL search parameters in a Next.js application.
* This hook allows for syncing the active tab with the browser's URL, enabling bookmarking and sharing of URLs with a specific tab activated.
*
* @param {UseTabSearchParamsOptions} options Configuration options for the hook:
* - `defaultTab`: The tab to default to when no tab is specified in the URL.
* - `routingBehavior`: Optional. Determines how changes to the active tab update the browser's history ('push' or 'replace'). Default is 'push'.
* - `searchParamName`: Optional. The name of the search parameter that holds the tab state in the URL. Default is 'category'.
* @returns A tuple where the first element is the active tab and the second element is a function to set the active tab.
*/
export const useTabSearchParams = ({
defaultTab,
routingBehavior = 'push',
searchParamName = 'category',
}: UseTabSearchParamsOptions) => {
const router = useRouter()
const pathName = usePathname()
const searchParams = useSearchParams()
const activeTab = searchParams.get(searchParamName) || defaultTab
const setActiveTab = (newActiveTab: string) => {
router[routingBehavior](`${pathName}?${searchParamName}=${newActiveTab}`)
}
return [activeTab, setActiveTab] as const
}
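For reference, a minimal usage sketch of the hook above (the component and tab names are hypothetical, not from this commit):

```tsx
'use client'
import { useTabSearchParams } from '@/hooks/use-tab-searchparams'

// Hypothetical component: the active tab is read from ?category=... in the URL,
// defaulting to 'builtin'; selecting a tab pushes the new URL onto history.
const ToolTabs = () => {
  const [activeTab, setActiveTab] = useTabSearchParams({
    defaultTab: 'builtin',
  })

  return (
    <div>
      <button onClick={() => setActiveTab('custom')}>Custom tools</button>
      <span>Active tab: {activeTab}</span>
    </div>
  )
}

export default ToolTabs
```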
...@@ -51,6 +51,7 @@ const translation = { ...@@ -51,6 +51,7 @@ const translation = {
authMethod: { authMethod: {
title: 'Authorization method', title: 'Authorization method',
type: 'Authorization type', type: 'Authorization type',
keyTooltip: 'HTTP header key. You can leave it as "Authorization" if you are not sure what it is, or set it to a custom value',
types: { types: {
none: 'None', none: 'None',
api_key: 'API Key', api_key: 'API Key',
...@@ -60,6 +61,14 @@ const translation = { ...@@ -60,6 +61,14 @@ const translation = {
key: 'Key', key: 'Key',
value: 'Value', value: 'Value',
}, },
authHeaderPrefix: {
title: 'Auth Type',
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Privacy policy', privacyPolicy: 'Privacy policy',
privacyPolicyPlaceholder: 'Please enter privacy policy', privacyPolicyPlaceholder: 'Please enter privacy policy',
}, },
......
...@@ -58,6 +58,13 @@ const translation = { ...@@ -58,6 +58,13 @@ const translation = {
key: 'Chave', key: 'Chave',
value: 'Valor', value: 'Valor',
}, },
authHeaderPrefix: {
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Política de Privacidade', privacyPolicy: 'Política de Privacidade',
privacyPolicyPlaceholder: 'Digite a política de privacidade', privacyPolicyPlaceholder: 'Digite a política de privacidade',
}, },
......
...@@ -58,6 +58,13 @@ const translation = { ...@@ -58,6 +58,13 @@ const translation = {
key: 'Ключ', key: 'Ключ',
value: 'Значення', value: 'Значення',
}, },
authHeaderPrefix: {
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: 'Політика конфіденційності', privacyPolicy: 'Політика конфіденційності',
privacyPolicyPlaceholder: 'Введіть політику конфіденційності', privacyPolicyPlaceholder: 'Введіть політику конфіденційності',
}, },
......
...@@ -51,6 +51,7 @@ const translation = { ...@@ -51,6 +51,7 @@ const translation = {
authMethod: { authMethod: {
title: '鉴权方法', title: '鉴权方法',
type: '鉴权类型', type: '鉴权类型',
keyTooltip: 'HTTP 头部名称,如果你不知道是什么,可以将其保留为 Authorization 或设置为自定义值',
types: { types: {
none: '无', none: '无',
api_key: 'API Key', api_key: 'API Key',
...@@ -60,6 +61,14 @@ const translation = { ...@@ -60,6 +61,14 @@ const translation = {
key: '键', key: '键',
value: '值', value: '值',
}, },
authHeaderPrefix: {
title: '鉴权头部前缀',
types: {
basic: 'Basic',
bearer: 'Bearer',
custom: 'Custom',
},
},
privacyPolicy: '隐私协议', privacyPolicy: '隐私协议',
privacyPolicyPlaceholder: '请输入隐私协议', privacyPolicyPlaceholder: '请输入隐私协议',
}, },
......
...@@ -15,7 +15,7 @@ export type App = { ...@@ -15,7 +15,7 @@ export type App = {
app_id: string app_id: string
description: string description: string
copyright: string copyright: string
privacy_policy: string privacy_policy: string | null
category: AppCategory category: AppCategory
position: number position: number
is_listed: boolean is_listed: boolean
......
{ {
"name": "dify-web", "name": "dify-web",
"version": "0.5.6", "version": "0.5.7",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "next dev", "dev": "next dev",
......
import { del, get, patch, post } from './base' import { del, get, patch, post } from './base'
import type { App, AppCategory } from '@/models/explore'
export const fetchAppList = () => { export const fetchAppList = () => {
return get('/explore/apps') return get<{
categories: AppCategory[]
recommended_apps: App[]
}>('/explore/apps')
} }
export const fetchAppDetail = (id: string): Promise<any> => { export const fetchAppDetail = (id: string): Promise<any> => {
......