Commit f7a90f26 authored by Joel

merge main

parents 3d825dcb dd961985
......@@ -21,6 +21,17 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
<p align="center">
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
</a>
<ul align="center" style="text-decoration: none; list-style: none;">
<li> US EST: 09:00 (9:00 AM)</li>
<li> CET: 15:00 (3:00 PM)</li>
<li> CST: 22:00 (10:00 PM)</li>
</ul>
</p>
<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
......
......@@ -6,15 +6,15 @@ import click
from flask import current_app
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models.account import Tenant
from models.dataset import Dataset
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account
from models.provider import Provider, ProviderModel
......@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
@click.command('vdb-migrate', help='migrate vector db.')
def vdb_migrate():
"""
Migrate other vector database data to Qdrant.
Migrate vector database data to the target vector database.
"""
click.echo(click.style('Start creating qdrant indexes.', fg='green'))
click.echo(click.style('Start migrating vector db.', fg='green'))
create_count = 0
config = current_app.config
vector_type = config.get('VECTOR_STORE')
page = 1
while True:
try:
......@@ -140,54 +141,101 @@ def create_qdrant_indexes():
except NotFound:
break
model_manager = ModelManager()
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
continue
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
create_count += 1
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
continue
if vector_type == "weaviate":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "qdrant":
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
click.echo('passed.')
raise ValueError('Dataset Collection Binding does not exist!')
else:
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"vdb_migrate {dataset.id}")
try:
vector.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
vector.create(documents)
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
raise e
click.echo(f"Dataset {dataset.id} create successfully.")
db.session.add(dataset)
db.session.commit()
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Created {} dataset indexes.'.format(create_count), fg='green'))
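For orientation, a minimal sketch (the helper name is hypothetical) of the collection-naming convention that the weaviate, qdrant, and milvus branches above all share: the dataset UUID has its dashes replaced with underscores and is wrapped in a fixed prefix and suffix.

def build_collection_name(dataset_id: str) -> str:
    # e.g. "1c0f-ab12" -> "Vector_index_1c0f_ab12_Node"
    return "Vector_index_" + dataset_id.replace("-", "_") + "_Node"

assert build_collection_name("1c0f-ab12") == "Vector_index_1c0f_ab12_Node"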
......@@ -196,4 +244,4 @@ def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(vdb_migrate)
......@@ -38,7 +38,9 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
......@@ -88,7 +90,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.6"
self.CURRENT_VERSION = "0.5.7"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
......@@ -261,8 +263,10 @@ class Config:
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
......
......@@ -124,19 +124,13 @@ class AppListApi(Resource):
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
if provider_model not in available_models_names:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_instance:
if not default_model_entity:
raise ProviderNotInitializeError(
"No Default System Reasoning Model available. Please configure "
"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_config_dict["model"]["provider"] = default_model_entity.provider
model_config_dict["model"]["name"] = default_model_entity.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
......
......@@ -178,7 +178,8 @@ class DataSourceNotionApi(Resource):
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id
)
text_docs = extractor.extract()
......@@ -208,7 +209,8 @@ class DataSourceNotionApi(Resource):
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type']
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
......
......@@ -298,7 +298,8 @@ class DatasetIndexingEstimateApi(Resource):
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type']
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
......
......@@ -455,7 +455,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['type']
"notion_page_type": data_source_info['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=document.doc_form
)
......
from extensions.ext_database import db
from models.model import EndUser
def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
import json
from flask import current_app
from flask_restful import fields, marshal_with
from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
class AppParameterApi(AppApiResource):
class AppParameterApi(Resource):
"""Resource for app variables."""
variable_fields = {
......@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields)
}
@validate_app_token
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
def get(self, app_model: App):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
......@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
}
}
class AppMetaApi(AppApiResource):
def get(self, app_model: App, end_user):
class AppMetaApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config
......
import logging
from flask import request
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
......@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
......@@ -30,8 +30,9 @@ from services.errors.audio import (
)
class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
......@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError()
class TextApi(AppApiResource):
def post(self, app_model: App, end_user):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
......@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=args['user'],
end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)
......
......@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
......@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService
class CompletionApi(AppApiResource):
def post(self, app_model, end_user):
class CompletionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion':
raise AppUnavailableError()
......@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False
try:
......@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError()
class CompletionStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user must be provided.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
class ChatApi(AppApiResource):
def post(self, app_model, end_user):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
......@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
......@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
response = CompletionService.completion(
app_model=app_model,
......@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError()
class ChatStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user must be provided.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
......
from flask import request
from flask_restful import marshal_with, reqparse
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService
class ConversationApi(AppApiResource):
class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id):
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
......@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
......@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.rename(
app_model,
......
from flask import request
from flask_restful import marshal_with
from flask_restful import Resource, marshal_with
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService
class FileApi(AppApiResource):
class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model, end_user):
def post(self, app_model: App, end_user: EndUser):
file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file
if 'file' not in request.files:
......
from flask_restful import fields, marshal_with, reqparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from extensions.ext_database import db
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message
from models.model import App, EndUser
from services.message_service import MessageService
class MessageListApi(AppApiResource):
class MessageListApi(Resource):
feedback_fields = {
'rating': fields.String
}
......@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields))
}
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
......@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
......@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(AppApiResource):
def post(self, app_model, end_user, message_id):
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
......@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
user=end_user,
message_id=message_id,
check_enabled=False
)
......
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
def validate_app_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = 'query'
JSON = 'json'
FORM = 'form'
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
......@@ -29,15 +47,34 @@ def validate_app_token(view=None):
if not app_model.enable_api:
raise NotFound()
return view(app_model, None, *args, **kwargs)
return decorated
kwargs['app_model'] = app_model
if view:
return decorator(view)
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
if user_id:
user_id = str(user_id)
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
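A minimal usage sketch (ExampleApi is hypothetical) of the reworked decorator: with fetch_user_arg set, token validation, app lookup, and end-user resolution all happen inside the decorator, which injects app_model and end_user as keyword arguments so handlers no longer parse a 'user' field themselves.

from flask_restful import Resource

class ExampleApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser):
        # both arguments arrive pre-resolved; no reqparse 'user' argument needed
        return {'result': 'success'}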
def cloud_edition_billing_resource_check(resource: str,
......@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
return api_token
class AppApiResource(Resource):
method_decorators = [validate_app_token]
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource):
......
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Get the remaining tokens available to the model after excluding the message tokens and the completion's max tokens
:param model_config:
:param messages:
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass
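get_message_rest_tokens subtracts the reserved completion budget and the prompt's token count from the model's context window. A worked example with assumed numbers:

# Assumed values: a 16,385-token context window, max_tokens=512 reserved for
# the completion, and a 1,200-token prompt.
model_context_tokens = 16385
max_tokens = 512
prompt_tokens = 1200
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
assert rest_tokens == 14673  # headroom left for additional prompt content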
import json
import logging
from typing import cast
......@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
......@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
message_chain = self._init_message_chain(
message=message,
query=query
)
# init model instance
model_instance = ModelInstance(
......@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables
})
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""
......
......@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
......
......@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
'id': self._message.id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'answer': self._task_state.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}
......
......@@ -3,12 +3,12 @@ import logging
from typing import Optional, cast
import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
......
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
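A brief sketch, assuming the consumption pattern visible later in this diff, of how the relocated enum drives strategy selection: function calling is preferred whenever the model schema advertises tool-call support, otherwise the runner falls back to ReAct.

from core.entities.agent_entities import PlanningStrategy
from core.model_runtime.entities.model_entities import ModelFeature

features = [ModelFeature.TOOL_CALL]  # assumed example value from a model schema
strategy = (PlanningStrategy.FUNCTION_CALL
            if ModelFeature.TOOL_CALL in features
            else PlanningStrategy.REACT)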
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if the dataset has no available documents
if dataset and dataset.available_document_count == 0:
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool
\ No newline at end of file
......@@ -59,7 +59,7 @@ class AnnotationReplyFeature:
documents = vector.search_by_vector(
query=query,
k=1,
top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
......
......@@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
for message in messages:
result.append(UserPromptMessage(content=message.query))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_inputs = json.loads(agent_thought.tool_input)
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_inputs = json.loads(agent_thought.tool_input)
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool,
tool_call_id=tool_call_id,
))
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
tool_call_id=tool_call_id,
))
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
return result
\ No newline at end of file
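A hedged illustration (tool names and inputs are invented) of the reconstruction above: each recorded tool in an agent thought gets a freshly generated tool_call_id, which ties the assistant message's tool_call to the ToolPromptMessage that echoes its observation.

import json
import uuid

tool_inputs = json.loads('{"search": {"q": "dify"}, "weather": {"city": "SF"}}')
for tool in 'search;weather'.split(';'):
    # one synthetic id links the assistant's tool_call to its tool response
    tool_call_id = str(uuid.uuid4())
    print(tool, tool_call_id, json.dumps(tool_inputs.get(tool, {})))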
......@@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
thought='',
action_str='',
observation='',
action=None
action=None,
)
# publish agent thought if it's first iteration
......@@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
thought=message.content,
action_str='',
action=None,
observation=None
observation=None,
)
if message.tool_calls:
try:
......@@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
current_scratchpad.observation = message.content
return agent_scratchpad
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
......@@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_message.content = system_message
overridden = True
break
# convert tool prompt messages to user prompt messages
for idx, prompt_message in enumerate(prompt_messages):
if isinstance(prompt_message, ToolPromptMessage):
prompt_messages[idx] = UserPromptMessage(
content=prompt_message.content
)
if not overridden:
prompt_messages.insert(0, SystemPromptMessage(
......
......@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
......
......@@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
......
......@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
......
import enum
import logging
from typing import Optional, Union
......@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
......@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
......@@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]
......
......@@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
......
......@@ -104,37 +104,17 @@ class HostingConfiguration:
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
restrict_models=trial_models
)
quotas.append(trial_quota)
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(
restrict_models=[
RestrictModel(model="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
restrict_models=paid_models
)
quotas.append(paid_quota)
......@@ -258,3 +238,11 @@ class HostingConfiguration:
return HostedModerationConfig(
enabled=False
)
@staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
models_list = models_str.split(",") if models_str else []
return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
model_name.strip()]
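A hypothetical illustration of the parsing rule the helper applies: split on commas, trim whitespace, and drop empty segments before wrapping each name in an LLM-typed RestrictModel.

models_str = "gpt-4, gpt-3.5-turbo, "  # assumed env value
model_names = [name.strip() for name in models_str.split(",") if name.strip()]
assert model_names == ["gpt-4", "gpt-3.5-turbo"]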
......@@ -365,8 +365,9 @@ class IndexingRunner:
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['notion_page_type'],
"document": dataset_document
"notion_page_type": data_source_info['type'],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id
},
document_model=dataset_document.doc_form
)
......@@ -664,6 +665,7 @@ class IndexingRunner:
)
# load index
index_processor.load(dataset, chunk_documents)
db.session.add(dataset)
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
......
......@@ -20,7 +20,7 @@
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
Displays the list of all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For the rule design, see [Schema](./schema.md)
Displays the list of all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For the rule design, see [Schema](./docs/zh_Hans/schema.md)
- Displaying the selectable model list
......@@ -86,4 +86,4 @@ Model Runtime is split into three layers:
![Alt text](docs/zh_Hans/images/index/image-2.png)
### [Concrete interface implementations 👈🏻](./docs/zh_Hans/interfaces.md)
Here you can find the concrete implementation of the interface you want to inspect, along with the exact meaning of its parameters and return values.
\ No newline at end of file
Here you can find the concrete implementation of the interface you want to inspect, along with the exact meaning of its parameters and return values.
......@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
'min': 1,
'max': 2048,
'precision': 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
'label': {
'en_US': 'Response Format',
'zh_Hans': '回复格式',
},
'type': 'string',
'help': {
'en_US': 'Set a response format to ensure the output from the LLM is a valid code block whenever possible, such as JSON, XML, etc.',
'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
},
'required': False,
'options': ['JSON', 'XML'],
}
}
\ No newline at end of file
......@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
......
......@@ -262,23 +262,23 @@ class AIModel(ABC):
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max:
if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min:
if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.precision:
if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision:
if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required:
if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help:
if not parameter_rule.help and 'help' in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
if not parameter_rule.help.en_US:
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans:
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
except ValueError:
pass
......
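The point of the added `'key' in default_parameter_rule` guards is that string-typed templates such as `RESPONSE_FORMAT` define no `min`/`max`/`precision`, so the previous unconditional indexing would raise `KeyError`. A reduced illustration of the pattern with plain dicts:

```python
def merge_rule(rule: dict, defaults: dict) -> dict:
    # Copy a default only when the rule leaves it unset AND the template
    # actually defines it -- response_format carries no numeric bounds.
    for key in ("max", "min", "default", "precision", "required"):
        if not rule.get(key) and key in defaults:
            rule[key] = defaults[key]
    return rule

response_format_defaults = {"required": False, "options": ["JSON", "XML"]}
print(merge_rule({"name": "response_format"}, response_format_defaults))
# {'name': 'response_format', 'required': False} -- no KeyError on 'max'
```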
......@@ -6,6 +6,7 @@
- bedrock
- togetherai
- ollama
- mistralai
- replicate
- huggingface_hub
- zhipuai
......
......@@ -27,6 +27,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
......
......@@ -27,6 +27,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '8.00'
output: '24.00'
......
......@@ -26,6 +26,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
......
......@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from httpx import Timeout
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
......@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are unsure about it.
<instructions>
{{instructions}}
</instructions>
"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
......@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
self._transform_json_prompts(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(
content=f"```{response_format}\n"
))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
......
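To make the effect of `_transform_json_prompts` concrete, here is a self-contained sketch with a plain dataclass standing in for the prompt message entities; the template text is abridged from the constant above.

```python
from dataclasses import dataclass

@dataclass
class Msg:  # stand-in for SystemPromptMessage / AssistantPromptMessage
    role: str
    content: str

BLOCK_MODE_PROMPT = (
    "You should always follow the instructions and output a valid {{block}} object.\n"
    "<instructions>\n{{instructions}}\n</instructions>\n"
)

def transform_json_prompts(messages: list[Msg], stop: list[str],
                           response_format: str = "JSON") -> None:
    if "```\n" not in stop:
        stop.append("```\n")  # stop generation once the model closes the block
    if messages and messages[0].role == "system":
        # fold the original system prompt into the block-mode template
        messages[0] = Msg("system", BLOCK_MODE_PROMPT
                          .replace("{{instructions}}", messages[0].content)
                          .replace("{{block}}", response_format))
    else:
        messages.insert(0, Msg("system", BLOCK_MODE_PROMPT
                               .replace("{{instructions}}",
                                        f"Please output a valid {response_format} object.")
                               .replace("{{block}}", response_format)))
    # pre-fill the assistant turn so generation starts inside the code block
    messages.append(Msg("assistant", f"```{response_format}\n"))

msgs = [Msg("system", "Answer with a city field."), Msg("user", "Where is the Louvre?")]
stop: list[str] = []
transform_json_prompts(msgs, stop)
print(stop)              # ['```\n']
print(msgs[-1].content)  # ```JSON
```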
......@@ -27,6 +27,8 @@ parameter_rules:
default: 2048
min: 1
max: 2048
- name: response_format
use_template: response_format
pricing:
input: '0.00'
output: '0.00'
......
......@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are unsure about it.
<instructions>
{{instructions}}
</instructions>
"""
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
......@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
......
- open-mistral-7b
- open-mixtral-8x7b
- mistral-small-latest
- mistral-medium-latest
- mistral-large-latest
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# mistral does not support user/stop arguments
stop = []
user = None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
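Since the wrapper only layers `mode` and `endpoint_url` onto the generic OpenAI-compatible client, the wire format can be exercised directly. A hedged sketch with `requests`: the API key comes from a placeholder environment variable, and the payload shape assumes Mistral's endpoint follows the OpenAI chat-completions schema, as the `OAIAPICompatLargeLanguageModel` base class implies.

```python
import os
import requests

resp = requests.post(
    "https://api.mistral.ai/v1/chat/completions",  # endpoint_url from the wrapper above
    headers={"Authorization": f"Bearer {os.environ['MISTRAL_API_KEY']}"},
    json={
        "model": "open-mistral-7b",
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 32,
        # note: no `user` or `stop` fields -- the wrapper above drops both
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```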
model: mistral-large-latest
label:
zh_Hans: mistral-large-latest
en_US: mistral-large-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD
model: mistral-medium-latest
label:
zh_Hans: mistral-medium-latest
en_US: mistral-medium-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.0027'
output: '0.0081'
unit: '0.001'
currency: USD
model: mistral-small-latest
label:
zh_Hans: mistral-small-latest
en_US: mistral-small-latest
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.002'
output: '0.006'
unit: '0.001'
currency: USD
model: open-mistral-7b
label:
zh_Hans: open-mistral-7b
en_US: open-mistral-7b
model_type: llm
features:
- agent-thought
model_properties:
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 2048
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.00025'
output: '0.00025'
unit: '0.001'
currency: USD
model: open-mixtral-8x7b
label:
zh_Hans: open-mixtral-8x7b
en_US: open-mixtral-8x7b
model_type: llm
features:
- agent-thought
model_properties:
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 8000
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.0007'
output: '0.0007'
unit: '0.001'
currency: USD
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class MistralAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='open-mistral-7b',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
provider: mistralai
label:
en_US: MistralAI
description:
en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest.
zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from MistralAI
zh_Hans: 从 MistralAI 获取 API Key
url:
en_US: https://console.mistral.ai/api-keys/
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
......@@ -13,6 +13,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
user = user[:32] if user else None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
......
......@@ -24,6 +24,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output.
required: false
options:
- text
- json_object
pricing:
input: '0.0005'
output: '0.0015'
......
......@@ -24,6 +24,8 @@ parameter_rules:
default: 512
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '0.0015'
output: '0.002'
......
......@@ -24,6 +24,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output.
required: false
options:
- text
- json_object
pricing:
input: '0.001'
output: '0.002'
......
......@@ -24,6 +24,8 @@ parameter_rules:
default: 512
min: 1
max: 16385
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.004'
......
......@@ -24,6 +24,8 @@ parameter_rules:
default: 512
min: 1
max: 16385
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.004'
......
......@@ -21,6 +21,8 @@ parameter_rules:
default: 512
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '0.0015'
output: '0.002'
......
......@@ -24,6 +24,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output.
required: false
options:
- text
- json_object
pricing:
input: '0.001'
output: '0.002'
......
......@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
......@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
logger = logging.getLogger(__name__)
OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are unsure about it.
<instructions>
{{instructions}}
</instructions>
"""
class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
"""
......@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
user=user
)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
# handle fine tune remote models
base_model = model
if model.startswith('ft:'):
base_model = model.split(':')[1]
# get model mode
model_mode = self.get_model_mode(base_model, credentials)
# transform response format
if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
stop = stop or []
if model_mode == LLMMode.CHAT:
# chat model
self._transform_chat_json_prompts(
model=base_model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
else:
self._transform_completion_json_prompts(
model=base_model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def _transform_completion_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# override the last user message
user_message = None
for i in range(len(prompt_messages) - 1, -1, -1):
if isinstance(prompt_messages[i], UserPromptMessage):
user_message = prompt_messages[i]
break
if user_message:
if prompt_messages[i].content[-11:] == 'Assistant: ':
# now we are in the chat app, remove the last assistant message
prompt_messages[i].content = prompt_messages[i].content[:-11]
prompt_messages[i] = UserPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", user_message.content)
.replace("{{block}}", response_format)
)
prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
else:
prompt_messages[i] = UserPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", user_message.content)
.replace("{{block}}", response_format)
)
prompt_messages[i].content += f"\n```{response_format}\n"
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
......
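The fine-tune handling at the top of `_code_block_mode_wrapper` relies on OpenAI's `ft:<base>:<org>:<suffix>:<id>` naming convention; a small sketch of the extraction:

```python
def base_model_of(model: str) -> str:
    # Fine-tuned models are named like "ft:gpt-3.5-turbo-0613:acme::abc123";
    # field 1 of the colon-separated name is the base model.
    return model.split(":")[1] if model.startswith("ft:") else model

assert base_model_of("ft:gpt-3.5-turbo-0613:acme::abc123") == "gpt-3.5-turbo-0613"
assert base_model_of("gpt-4") == "gpt-4"
```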
......@@ -13,6 +13,7 @@ from dashscope.common.error import (
)
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
......@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
-> LLMResult | Generator:
"""
Wrapper for code block mode
"""
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are unsure about it.
<instructions>
{{instructions}}
</instructions>
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
mode = self.get_model_mode(model, credentials)
if mode == LLMMode.CHAT:
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
else:
prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=response
)
return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
......@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
"""
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
extra_model_kwargs['stop'] = stop
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
......@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
params = {
'model': model,
**model_parameters,
**credentials_kwargs
**credentials_kwargs,
**extra_model_kwargs,
}
mode = self.get_model_mode(model, credentials)
......
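`_code_block_mode_stream_processor_with_backtick` is referenced but not shown in this diff. A minimal sketch of the idea, assuming chunks arrive as plain strings: trim the stream at the closing fence, keeping a short tail in the buffer in case the fence is split across chunks.

```python
def strip_after_closing_fence(chunks):
    """Yield streamed text up to the closing ``` fence, dropping the fence
    and anything after it."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        if "```" in buffer:
            head, _, _ = buffer.partition("```")
            if head:
                yield head
            return
        if len(buffer) > 2:
            yield buffer[:-2]   # flush all but a tail that could be a split fence
            buffer = buffer[-2:]
    if buffer:
        yield buffer

print("".join(strip_after_closing_fence(['{"answer"', ': "42"}\n``', '`\nextra'])))
# {"answer": "42"}
```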
......@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
required: false
- name: response_format
use_template: response_format
......@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
required: false
- name: response_format
use_template: response_format
......@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
required: false
- name: response_format
use_template: response_format
......@@ -56,6 +56,8 @@ parameter_rules:
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
- name: response_format
use_template: response_format
pricing:
input: '0.02'
output: '0.02'
......
......@@ -57,6 +57,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no penalty.
required: false
- name: response_format
use_template: response_format
pricing:
input: '0.008'
output: '0.008'
......
......@@ -25,6 +25,8 @@ parameter_rules:
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
......
......@@ -25,6 +25,8 @@ parameter_rules:
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
......
......@@ -25,3 +25,5 @@ parameter_rules:
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
......@@ -34,3 +34,5 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false
- name: response_format
use_template: response_format
from collections.abc import Generator
from typing import cast
from typing import Optional, Union, cast
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
......@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
RateLimitReachedError,
)
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are unsure about it.
<instructions>
{{instructions}}
</instructions>
You should also complete the text that starts with ``` but do not output ``` directly.
"""
class ErnieBotLarguageModel(LargeLanguageModel):
class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
......@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel):
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
response_format = model_parameters['response_format']
stop = stop or []
self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
model_parameters.pop('response_format')
if stream:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
)
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts to model prompts
"""
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ERNIE_BOT_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ERNIE_BOT_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# add ```JSON\n to the last message
prompt_messages[-1].content += "\n```JSON\n{\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content="```JSON\n{\n"
))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
# tools is not supported yet
......
......@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are unsure about it.
And you should always end the block with a "```" to indicate the end of the JSON object.
<instructions>
{{instructions}}
</instructions>
```JSON"""
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
......@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# invoke model
# stop = stop or []
# self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
# def _transform_json_prompts(self, model: str, credentials: dict,
# prompt_messages: list[PromptMessage], model_parameters: dict,
# tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
# stream: bool = True, user: str | None = None) \
# -> None:
# """
# Transform json prompts to model prompts
# """
# if "}\n\n" not in stop:
# stop.append("}\n\n")
# # check if there is a system message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# # override the system message
# prompt_messages[0] = SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
# )
# else:
# # insert the system message
# prompt_messages.insert(0, SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
# ))
# # check if the last message is a user message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# # add ```JSON\n to the last message
# prompt_messages[-1].content += "\n```JSON\n"
# else:
# # append a user message
# prompt_messages.append(UserPromptMessage(
# content="```JSON\n"
# ))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
......@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
"""
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
extra_model_kwargs['stop'] = stop
client = ZhipuAI(
api_key=credentials_kwargs['api_key']
......@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
]
if stream:
response = client.chat.completions.create(stream=stream, **params)
response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
response = client.chat.completions.create(**params)
response = client.chat.completions.create(**params, **extra_model_kwargs)
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
def _handle_generate_response(self, model: str,
......
from typing import Any, cast
from typing import Any
from flask import current_app
......@@ -14,7 +14,7 @@ class Keyword:
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
config = cast(dict, current_app.config)
config = current_app.config
keyword_type = config.get('KEYWORD_STORE')
if not keyword_type:
......@@ -25,7 +25,7 @@ class Keyword:
dataset=self._dataset
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
raise ValueError(f"Keyword store {keyword_type} is not supported.")
def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
......
......@@ -2,7 +2,6 @@ import threading
from typing import Optional
from flask import Flask, current_app
from flask_login import current_user
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
......@@ -27,6 +26,11 @@ class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
# retrieval_model source with keyword
......@@ -35,7 +39,8 @@ class RetrievalService:
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k
'top_k': top_k,
'all_documents': all_documents
})
threads.append(keyword_thread)
keyword_thread.start()
......@@ -73,7 +78,7 @@ class RetrievalService:
thread.join()
if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
......@@ -96,7 +101,7 @@ class RetrievalService:
documents = keyword.search(
query,
k=top_k
top_k=top_k
)
all_documents.extend(documents)
......@@ -116,7 +121,7 @@ class RetrievalService:
documents = vector.search_by_vector(
query,
search_type='similarity_score_threshold',
k=top_k,
top_k=top_k,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
......
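The fix here adds `all_documents` to the keyword thread's kwargs so its results land in the shared list like the other retrievers'. A reduced sketch of the fan-out/join pattern (the search functions are stand-ins):

```python
import threading

def keyword_search(query: str, top_k: int, all_documents: list) -> None:
    all_documents.extend(f"kw:{query}:{i}" for i in range(top_k))   # stand-in

def vector_search(query: str, top_k: int, all_documents: list) -> None:
    all_documents.extend(f"vec:{query}:{i}" for i in range(top_k))  # stand-in

all_documents: list[str] = []
threads = [
    threading.Thread(target=keyword_search,
                     kwargs={"query": "louvre", "top_k": 2, "all_documents": all_documents}),
    threading.Thread(target=vector_search,
                     kwargs={"query": "louvre", "top_k": 2, "all_documents": all_documents}),
]
for t in threads:
    t.start()
for t in threads:
    t.join()            # both retrievers have merged their results here
print(all_documents)    # order depends on thread scheduling
```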
......@@ -7,4 +7,4 @@ class Field(Enum):
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = " id"
PRIMARY_KEY = "id"
......@@ -124,12 +124,23 @@ class MilvusVector(BaseVector):
def delete_by_ids(self, doc_ids: list[str]) -> None:
self._client.delete(collection_name=self._collection_name, pks=doc_ids)
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
from pymilvus import utility
utility.drop_collection(self._collection_name, None)
utility.drop_collection(self._collection_name, None, using=alias)
def text_exists(self, id: str) -> bool:
......
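The corrected `delete_by_ids` resolves logical `doc_id`s to Milvus primary keys before deleting. A sketch of the two-step call sequence, mirroring the client calls in the hunk above; the endpoint and collection name are placeholders:

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")  # placeholder endpoint
collection = "Vector_index_example_Node"
doc_ids = ["doc-a", "doc-b"]

# Step 1: resolve primary keys via a metadata filter.
result = client.query(
    collection_name=collection,
    filter=f'metadata["doc_id"] in {doc_ids}',
    output_fields=["id"],
)
# Step 2: delete by primary key, but only if anything matched.
if result:
    client.delete(collection_name=collection,
                  pks=[item["id"] for item in result])
```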
from typing import Any, cast
import json
from typing import Any
from flask import current_app
......@@ -22,7 +23,7 @@ class Vector:
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
config = cast(dict, current_app.config)
config = current_app.config
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
......@@ -39,6 +40,11 @@ class Vector:
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
......@@ -66,6 +72,13 @@ class Vector:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
if not self._dataset.index_struct_dict:
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return QdrantVector(
collection_name=collection_name,
group_id=self._dataset.id,
......@@ -84,6 +97,11 @@ class Vector:
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
......
......@@ -127,7 +127,10 @@ class WeaviateVector(BaseVector):
)
def delete(self):
self._client.schema.delete_class(self._collection_name)
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
......@@ -147,10 +150,11 @@ class WeaviateVector(BaseVector):
return True
def delete_by_ids(self, ids: list[str]) -> None:
self._client.data_object.delete(
ids,
class_name=self._collection_name
)
for uuid in ids:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
......
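A reduced sketch of the per-object deletion above, with `client` assumed to be an already-configured `weaviate.Client`; the loop exists because `data_object.delete` removes a single object per call, so a list of ids cannot be passed in one call:

```python
def delete_by_ids(client, collection_name: str, ids: list[str]) -> None:
    # delete objects one at a time -- data_object.delete takes a single uuid
    for uuid in ids:
        client.data_object.delete(class_name=collection_name, uuid=uuid)
```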
......@@ -12,6 +12,7 @@ class NotionInfo(BaseModel):
notion_obj_id: str
notion_page_type: str
document: Document = None
tenant_id: str
class Config:
arbitrary_types_allowed = True
......
......@@ -132,7 +132,8 @@ class ExtractProcessor:
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document
document_model=extract_setting.notion_info.document,
tenant_id=extract_setting.notion_info.tenant_id,
)
return extractor.extract()
else:
......
"""Abstract interface for document loader implementations."""
from typing import Optional
from bs4 import BeautifulSoup
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
class HtmlExtractor(BaseExtractor):
"""Load html files.
"""
Load html files.
Args:
......@@ -15,57 +16,19 @@ class HtmlExtractor(BaseExtractor):
"""
def __init__(
self,
file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None,
csv_args: Optional[dict] = None,
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column
self.csv_args = csv_args or {}
def extract(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self._file_path}") from e
return docs
return [Document(page_content=self._load_as_text())]
def _read_from_file(self, csvfile) -> list[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return docs
return text
\ No newline at end of file
......@@ -4,7 +4,6 @@ from typing import Any, Optional
import requests
from flask import current_app
from flask_login import current_user
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
......@@ -30,8 +29,10 @@ class NotionExtractor(BaseExtractor):
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None
notion_access_token: Optional[str] = None,
):
self._notion_access_token = None
self._document_model = document_model
......@@ -41,7 +42,7 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
self._notion_access_token = self._get_access_token(tenant_id,
self._notion_workspace_id)
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
......
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
import websocket
class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'
model_api_configs = {
'spark': {
'version': 'v1.1',
'chat_domain': 'general'
},
'spark-v2': {
'version': 'v2.1',
'chat_domain': 'generalv2'
},
'spark-v3': {
'version': 'v3.1',
'chat_domain': 'generalv3'
},
'spark-v3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
}
}
api_version = model_api_configs[model_name]['version']
self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,
urlparse(self.api_base).path,
self.api_base,
api_key,
api_secret
)
self.queue = queue.Queue()
self.blocking_message = ''
def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
# generate timestamp by RFC1123
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + path + " HTTP/1.1"
# encrypt using hmac-sha256
signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": host
}
# generate url
url = api_base + '?' + urlencode(v)
return url
def run(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None, streaming: bool = False):
websocket.enableTrace(False)
ws = websocket.WebSocketApp(
self.ws_url,
on_message=self.on_message,
on_error=self.on_error,
on_close=self.on_close,
on_open=self.on_open
)
ws.messages = messages
ws.user_id = user_id
ws.model_kwargs = model_kwargs
ws.streaming = streaming
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_error(self, ws, error):
self.queue.put({
'status_code': error.status_code,
'error': error.resp_body.decode('utf-8')
})
ws.close()
def on_close(self, ws, close_status_code, close_reason):
self.queue.put({'done': True})
def on_open(self, ws):
self.blocking_message = ''
data = json.dumps(self.gen_params(
messages=ws.messages,
user_id=ws.user_id,
model_kwargs=ws.model_kwargs
))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
self.queue.put({
'status_code': 400,
'error': f"Code: {code}, Error: {data['header']['message']}"
})
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
if ws.streaming:
self.queue.put({'data': content})
else:
self.blocking_message += content
if status == 2:
if not ws.streaming:
self.queue.put({'data': self.blocking_message})
ws.close()
def gen_params(self, messages: list, user_id: str,
model_kwargs: Optional[dict] = None) -> dict:
data = {
"header": {
"app_id": self.app_id,
"uid": user_id
},
"parameter": {
"chat": {
"domain": self.chat_domain
}
},
"payload": {
"message": {
"text": messages
}
}
}
if model_kwargs:
data['parameter']['chat'].update(model_kwargs)
return data
def subscribe(self):
while True:
content = self.queue.get()
if 'error' in content:
if content['status_code'] == 401:
raise SparkError('[Spark] The credentials you provided are incorrect. '
'Please double-check and fill them in again.')
elif content['status_code'] == 403:
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
"Please try again after obtaining the necessary permissions.")
else:
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
if 'data' not in content:
break
yield content
class SparkError(Exception):
pass
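The `create_url` signing scheme can be checked in isolation. A self-contained sketch reproducing the HMAC-SHA256 authorization value for fixed inputs; the key and secret are dummies:

```python
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from wsgiref.handlers import format_date_time

api_key, api_secret = "demo-key", "demo-secret"   # dummy credentials
host, path = "spark-api.xf-yun.com", "/v3.5/chat"

date = format_date_time(mktime(datetime.now().timetuple()))  # RFC 1123 date
signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
signature = base64.b64encode(
    hmac.new(api_secret.encode("utf-8"), signature_origin.encode("utf-8"),
             digestmod=hashlib.sha256).digest()
).decode()
authorization_origin = (
    f'api_key="{api_key}", algorithm="hmac-sha256", '
    f'headers="host date request-line", signature="{signature}"'
)
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode()
print(authorization[:32], "...")  # ends up in the ?authorization= query param
```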
from datetime import datetime
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class DatetimeToolInput(BaseModel):
type: str = Field(..., description="Type for current time, must be: datetime.")
class DatetimeTool(BaseTool):
"""Tool for querying current datetime."""
name: str = "current_datetime"
args_schema: type[BaseModel] = DatetimeToolInput
description: str = "A tool when you want to get the current date, time, week, month or year, " \
"and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."
def _run(self, type: str) -> str:
# get current time
current_time = datetime.utcnow()
return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
import base64
from abc import ABC, abstractmethod
from typing import Optional
from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName
class BaseToolProvider(ABC):
def __init__(self, tenant_id: str):
self.tenant_id = tenant_id
@abstractmethod
def get_provider_name(self) -> ToolProviderName:
raise NotImplementedError
@abstractmethod
def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_to_func_kwargs(self) -> Optional[dict]:
raise NotImplementedError
@abstractmethod
def credentials_validate(self, credentials: dict):
raise NotImplementedError
def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
"""
Returns the Provider instance for the given tenant_id and tool_name.
"""
query = db.session.query(ToolProvider).filter(
ToolProvider.tenant_id == self.tenant_id,
ToolProvider.tool_name == self.get_provider_name().value
)
if must_enabled:
query = query.filter(ToolProvider.is_enabled == True)
return query.first()
def encrypt_token(self, token) -> str:
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
token = rsa.decrypt(base64.b64decode(token), self.tenant_id)
if obfuscated:
return self._obfuscated_token(token)
return token
def _obfuscated_token(self, token: str) -> str:
return token[:6] + '*' * (len(token) - 8) + token[-2:]
class ToolValidateFailedError(Exception):
description = "Tool Provider Validate failed"
This diff is collapsed.