Unverified Commit 9c486d44 authored by Yeuoly

Merge branch 'main' into feat/enterprise

parents 138efe9e f1cbd550
@@ -21,6 +21,17 @@
     <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
 </p>
+<p align="center">
+  <a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
+    Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
+  </a>
+  <ul align="center" style="text-decoration: none; list-style: none;">
+    <li> US EST: 09:00 (9:00 AM)</li>
+    <li> CET: 15:00 (3:00 PM)</li>
+    <li> CST: 22:00 (10:00 PM)</li>
+  </ul>
+</p>
 <p align="center">
   <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
     Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
...
@@ -134,3 +134,5 @@ UNSTRUCTURED_API_URL=
 SSRF_PROXY_HTTP_URL=
 SSRF_PROXY_HTTPS_URL=
+
+BATCH_UPLOAD_LIMIT=10
\ No newline at end of file
@@ -6,15 +6,15 @@ import click
 from flask import current_app
 from werkzeug.exceptions import NotFound

-from core.embedding.cached_embedding import CacheEmbedding
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from models.model import Account
 from models.provider import Provider, ProviderModel
@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
         'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))


-@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
-def create_qdrant_indexes():
+@click.command('vdb-migrate', help='migrate vector db.')
+def vdb_migrate():
     """
-    Migrate other vector database datas to Qdrant.
+    Migrate vector database datas to target vector database.
     """
-    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    click.echo(click.style('Start migrate vector db.', fg='green'))
     create_count = 0
+    config = current_app.config
+    vector_type = config.get('VECTOR_STORE')
     page = 1
     while True:
         try:
@@ -140,54 +141,101 @@ def create_qdrant_indexes():
         except NotFound:
             break

-        model_manager = ModelManager()
-
         page += 1
         for dataset in datasets:
-            if dataset.index_struct_dict:
-                if dataset.index_struct_dict['type'] != 'qdrant':
-                    try:
-                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
-                        try:
-                            embedding_model = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model
-                            )
-                        except Exception:
-                            continue
-                        embeddings = CacheEmbedding(embedding_model)
-
-                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-                        index = QdrantVectorIndex(
-                            dataset=dataset,
-                            config=QdrantConfig(
-                                endpoint=current_app.config.get('QDRANT_URL'),
-                                api_key=current_app.config.get('QDRANT_API_KEY'),
-                                root_path=current_app.root_path
-                            ),
-                            embeddings=embeddings
-                        )
-                        if index:
-                            index.create_qdrant_dataset(dataset)
-                            index_struct = {
-                                "type": 'qdrant',
-                                "vector_store": {
-                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                            }
-                            dataset.index_struct = json.dumps(index_struct)
-                            db.session.commit()
-                            create_count += 1
-                        else:
-                            click.echo('passed.')
-                    except Exception as e:
-                        click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                        fg='red'))
-                        continue
+            try:
+                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                if dataset.index_struct_dict:
+                    if dataset.index_struct_dict['type'] == vector_type:
+                        continue
+                if vector_type == "weaviate":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'weaviate',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "qdrant":
+                    if dataset.collection_binding_id:
+                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                            one_or_none()
+                        if dataset_collection_binding:
+                            collection_name = dataset_collection_binding.collection_name
+                        else:
+                            raise ValueError('Dataset Collection Bindings is not exist!')
+                    else:
+                        dataset_id = dataset.id
+                        collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "milvus":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'milvus',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                else:
+                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+                vector = Vector(dataset)
+                click.echo(f"vdb_migrate {dataset.id}")
+
+                try:
+                    vector.delete()
+                except Exception as e:
+                    raise e
+
+                dataset_documents = db.session.query(DatasetDocument).filter(
+                    DatasetDocument.dataset_id == dataset.id,
+                    DatasetDocument.indexing_status == 'completed',
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                ).all()
+
+                documents = []
+                for dataset_document in dataset_documents:
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.status == 'completed',
+                        DocumentSegment.enabled == True
+                    ).all()
+
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+                        documents.append(document)
+
+                if documents:
+                    try:
+                        vector.create(documents)
+                    except Exception as e:
+                        raise e
+                click.echo(f"Dataset {dataset.id} create successfully.")
+                db.session.add(dataset)
+                db.session.commit()
+                create_count += 1
+            except Exception as e:
+                db.session.rollback()
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue

     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -196,4 +244,4 @@ def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
-    app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(vdb_migrate)
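
A quick sketch of what the renamed command amounts to end to end; the `flask vdb-migrate` invocation and the helper below are illustrative, not part of the diff:

# Illustrative only: the command registered above would be run via the Flask CLI as
#   flask vdb-migrate
# and, per dataset, the migration reduces to "drop the old index, rebuild it
# from the stored segments":
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document

def migrate_dataset(dataset, segments):  # hypothetical helper
    vector = Vector(dataset)  # factory picks the store from VECTOR_STORE / index_struct
    vector.delete()           # drop any stale index for this dataset
    documents = [
        Document(page_content=s.content,
                 metadata={"doc_id": s.index_node_id,
                           "doc_hash": s.index_node_hash,
                           "document_id": s.document_id,
                           "dataset_id": s.dataset_id})
        for s in segments
    ]
    if documents:
        vector.create(documents)  # re-embed and write into the target store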
@@ -38,7 +38,9 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
+    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -56,6 +58,8 @@ DEFAULTS = {
     'BILLING_ENABLED': 'False',
     'CAN_REPLACE_LOGO': 'False',
     'ETL_TYPE': 'dify',
+    'KEYWORD_STORE': 'jieba',
+    'BATCH_UPLOAD_LIMIT': 20
 }
@@ -187,7 +191,7 @@ class Config:
         # Currently, only support: qdrant, milvus, zilliz, weaviate
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
+        self.KEYWORD_STORE = get_env('KEYWORD_STORE')
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -264,8 +268,10 @@ class Config:
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
         self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
+        self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
         self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -290,6 +296,8 @@ class Config:
         self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
         self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
+        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
+

 class CloudEditionConfig(Config):
...
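
For reference, a minimal sketch of how the DEFAULTS fallback above behaves; `get_env` here is a simplified stand-in for the real helper in config.py:

import os

DEFAULTS = {'KEYWORD_STORE': 'jieba', 'BATCH_UPLOAD_LIMIT': 20}

def get_env(key):
    # simplified: the real helper lives in config.py
    return os.environ.get(key, DEFAULTS.get(key))

# Note the asymmetry: without an env override this yields the int 20, while
# an override arrives as a string, so call sites should cast explicitly:
batch_limit = int(get_env('BATCH_UPLOAD_LIMIT'))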
@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
@@ -173,14 +174,15 @@ class DataSourceNotionApi(Resource):
         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        loader = NotionLoader(
-            notion_access_token=data_source_binding.access_token,
+        extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
-            notion_page_type=page_type
+            notion_page_type=page_type,
+            notion_access_token=data_source_binding.access_token,
+            tenant_id=current_user.current_tenant_id
         )

-        text_docs = loader.load()
+        text_docs = extractor.extract()
         return {
             'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200
@@ -192,11 +194,31 @@ class DataSourceNotionApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        notion_info_list = args['notion_info_list']
+        extract_settings = []
+        for notion_info in notion_info_list:
+            workspace_id = notion_info['workspace_id']
+            for page in notion_info['pages']:
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": workspace_id,
+                        "notion_obj_id": page['page_id'],
+                        "notion_page_type": page['type'],
+                        "tenant_id": current_user.current_tenant_id
+                    },
+                    document_model=args['doc_form']
+                )
+                extract_settings.append(extract_setting)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                     args['process_rule'], args['doc_form'],
+                                                     args['doc_language'])
         return response, 200
...
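
A hypothetical request body for the estimate endpoint above, showing the two new optional fields and the `notion_info_list` shape the new loop iterates over (all values are illustrative):

payload = {
    "notion_info_list": [
        {
            "workspace_id": "ws-123",                           # illustrative id
            "pages": [{"page_id": "page-abc", "type": "page"}]  # one ExtractSetting per page
        }
    ],
    "process_rule": {"mode": "automatic"},
    "doc_form": "text_model",   # optional; defaults to 'text_model'
    "doc_language": "English"   # optional; defaults to 'English'
}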
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -178,9 +179,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True,
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('indexing_technique', type=str, required=True,
                             choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
+        extract_settings = []
         if args['info_list']['data_source_type'] == 'upload_file':
             file_ids = args['info_list']['file_info_list']['file_ids']
             file_details = db.session.query(UploadFile).filter(
@@ -278,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
             if file_details is None:
                 raise NotFound("File not found.")

-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  args['process_rule'], args['doc_form'],
-                                                                  args['doc_language'], args['dataset_id'],
-                                                                  args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            if file_details:
+                for file_detail in file_details:
+                    extract_setting = ExtractSetting(
+                        datasource_type="upload_file",
+                        upload_file=file_detail,
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         elif args['info_list']['data_source_type'] == 'notion_import':
-            indexing_runner = IndexingRunner()
-
-            try:
-                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                    args['info_list']['notion_info_list'],
-                                                                    args['process_rule'], args['doc_form'],
-                                                                    args['doc_language'], args['dataset_id'],
-                                                                    args['indexing_technique'])
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
+            notion_info_list = args['info_list']['notion_info_list']
+            for notion_info in notion_info_list:
+                workspace_id = notion_info['workspace_id']
+                for page in notion_info['pages']:
+                    extract_setting = ExtractSetting(
+                        datasource_type="notion_import",
+                        notion_info={
+                            "notion_workspace_id": workspace_id,
+                            "notion_obj_id": page['page_id'],
+                            "notion_page_type": page['type'],
+                            "tenant_id": current_user.current_tenant_id
+                        },
+                        document_model=args['doc_form']
+                    )
+                    extract_settings.append(extract_setting)
         else:
             raise ValueError('Data source type not support')
+        indexing_runner = IndexingRunner()
+        try:
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         args['process_rule'], args['doc_form'],
+                                                         args['doc_language'], args['dataset_id'],
+                                                         args['indexing_technique'])
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                "No Embedding Model available. Please configure a valid provider "
+                "in the Settings -> Model Provider.")
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         return response, 200
@@ -508,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
@@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.document_fields import (
@@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource):
         req_data = request.args
         document_id = req_data.get('document_id')
         # get default rules
         mode = DocumentService.DEFAULT_RULES['mode']
         rules = DocumentService.DEFAULT_RULES['rules']
@@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource):
                 if not file:
                     raise NotFound('File not found.')

+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file,
+                    document_model=document.doc_form
+                )
+
                 indexing_runner = IndexingRunner()

                 try:
-                    response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                      data_process_rule_dict, None,
-                                                                      'English', dataset_id)
+                    response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
+                                                                 data_process_rule_dict, document.doc_form,
+                                                                 'English', dataset_id)
                 except LLMBadRequestError:
                     raise ProviderNotInitializeError(
                         "No Embedding Model available. Please configure a valid provider "
@@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         data_process_rule = documents[0].dataset_process_rule
         data_process_rule_dict = data_process_rule.to_dict()
         info_list = []
+        extract_settings = []
         for document in documents:
             if document.indexing_status in ['completed', 'error']:
                 raise DocumentAlreadyFinishedError()
@@ -424,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 }
                 info_list.append(notion_info)

-        if dataset.data_source_type == 'upload_file':
-            file_details = db.session.query(UploadFile).filter(
-                UploadFile.tenant_id == current_user.current_tenant_id,
-                UploadFile.id.in_(info_list)
-            ).all()
-
-            if file_details is None:
-                raise NotFound("File not found.")
-
-            indexing_runner = IndexingRunner()
-            try:
-                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict, None,
-                                                                  'English', dataset_id)
-            except LLMBadRequestError:
-                raise ProviderNotInitializeError(
-                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
-            except ProviderTokenNotInitError as ex:
-                raise ProviderNotInitializeError(ex.description)
-        elif dataset.data_source_type == 'notion_import':
+            if document.data_source_type == 'upload_file':
+                file_id = data_source_info['upload_file_id']
+                file_detail = db.session.query(UploadFile).filter(
+                    UploadFile.tenant_id == current_user.current_tenant_id,
+                    UploadFile.id == file_id
+                ).first()
+
+                if file_detail is None:
+                    raise NotFound("File not found.")
+
+                extract_setting = ExtractSetting(
+                    datasource_type="upload_file",
+                    upload_file=file_detail,
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+
+            elif document.data_source_type == 'notion_import':
+                extract_setting = ExtractSetting(
+                    datasource_type="notion_import",
+                    notion_info={
+                        "notion_workspace_id": data_source_info['notion_workspace_id'],
+                        "notion_obj_id": data_source_info['notion_page_id'],
+                        "notion_page_type": data_source_info['type'],
+                        "tenant_id": current_user.current_tenant_id
+                    },
+                    document_model=document.doc_form
+                )
+                extract_settings.append(extract_setting)
+            else:
+                raise ValueError('Data source type not support')

         indexing_runner = IndexingRunner()
         try:
-            response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
-                                                                info_list,
-                                                                data_process_rule_dict,
-                                                                None, 'English', dataset_id)
+            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
+                                                         data_process_rule_dict, document.doc_form,
+                                                         'English', dataset_id)
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider "
                 "in the Settings -> Model Provider.")
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
-        else:
-            raise ValueError('Data source type not support')

         return response
...
-from langchain.schema import Document
-
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import InvokeFrom
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import DatasetQuery, DocumentSegment
 from models.model import DatasetRetrieverResource
...
import tempfile
from pathlib import Path
from typing import Optional, Union

import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile

SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(url).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            with open(file_path, 'wb') as file:
                file.write(response.content)

            return cls.load_from_file(file_path, return_text)

    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[list[Document], str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        etl_type = current_app.config['ETL_TYPE']
        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
        if etl_type == 'Unstructured':
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
                    else MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_extension == '.msg':
                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
            elif file_extension == '.eml':
                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
            elif file_extension == '.ppt':
                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
            elif file_extension == '.pptx':
                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
            elif file_extension == '.xml':
                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
            else:
                # txt
                loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
                    else TextLoader(file_path, autodetect_encoding=True)
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
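
Usage sketch for the (pre-refactor) FileExtractor above; the UploadFile row and configured storage backend are assumed:

from models.model import UploadFile

def extract_text(upload_file: UploadFile) -> str:
    # return_text=True joins the page_content of every extracted Document
    # with newlines; return_text=False yields the list[Document] itself.
    return FileExtractor.load(upload_file, return_text=True)

# Remote files go through load_from_url; only the content types listed in
# SUPPORT_URL_CONTENT_TYPES above are expected to extract cleanly, e.g.
#   text = FileExtractor.load_from_url("https://example.com/a.pdf", return_text=True)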
import logging
from typing import Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> list[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass
        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents
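
The loader caches extracted plaintext per file hash, so repeat loads skip PyPDFium2 entirely; a usage sketch (the path is illustrative):

# First load() extracts via PyPDFium2 and, when an UploadFile with a hash is
# given, writes the cache object
#   upload_files/<tenant_id>/<file_hash>.0625.plaintext
# A later load() for the same hash returns the cached text directly.
loader = PdfLoader(file_path='/tmp/example.pdf', upload_file=None)  # illustrative path
docs = loader.load()  # list[Document], one per PDF page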
 from collections.abc import Sequence
 from typing import Any, Optional, cast

-from langchain.schema import Document
 from sqlalchemy import func

 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment
...
@@ -3,12 +3,12 @@ import logging
 from typing import Optional, cast

 import numpy as np
-from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.rag.datasource.entity.embedding import Embeddings
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
...
 import logging
 from typing import Optional

-from flask import current_app
-
-from core.embedding.cached_embedding import CacheEmbedding
 from core.entities.application_entities import InvokeFrom
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
             embedding_provider_name = collection_binding_detail.provider_name
             embedding_model_name = collection_binding_detail.model_name

-            model_manager = ModelManager()
-            model_instance = model_manager.get_model_instance(
-                tenant_id=app_record.tenant_id,
-                provider=embedding_provider_name,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=embedding_model_name
-            )
-
-            # get embedding model
-            embeddings = CacheEmbedding(model_instance)
-
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                 embedding_provider_name,
                 embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
                 collection_binding_id=dataset_collection_binding.id
             )

-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings,
-                attributes=['doc_id', 'annotation_id', 'app_id']
-            )
+            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])

-            documents = vector_index.search(
-                query=query,
-                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': 1,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
-                }
-            )
+            documents = vector.search_by_vector(
+                query=query,
+                top_k=1,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
+                }
+            )
...
@@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
         for message in messages:
             result.append(UserPromptMessage(content=message.query))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
-            for agent_thought in agent_thoughts:
-                tools = agent_thought.tool
-                if tools:
-                    tools = tools.split(';')
-                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
-                    tool_call_response: list[ToolPromptMessage] = []
-                    tool_inputs = json.loads(agent_thought.tool_input)
-                    for tool in tools:
-                        # generate a uuid for tool call
-                        tool_call_id = str(uuid.uuid4())
-                        tool_calls.append(AssistantPromptMessage.ToolCall(
-                            id=tool_call_id,
-                            type='function',
-                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                name=tool,
-                                arguments=json.dumps(tool_inputs.get(tool, {})),
-                            )
-                        ))
-                        tool_call_response.append(ToolPromptMessage(
-                            content=agent_thought.observation,
-                            name=tool,
-                            tool_call_id=tool_call_id,
-                        ))
-
-                    result.extend([
-                        AssistantPromptMessage(
-                            content=agent_thought.thought,
-                            tool_calls=tool_calls,
-                        ),
-                        *tool_call_response
-                    ])
+            if agent_thoughts:
+                for agent_thought in agent_thoughts:
+                    tools = agent_thought.tool
+                    if tools:
+                        tools = tools.split(';')
+                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                        tool_call_response: list[ToolPromptMessage] = []
+                        tool_inputs = json.loads(agent_thought.tool_input)
+                        for tool in tools:
+                            # generate a uuid for tool call
+                            tool_call_id = str(uuid.uuid4())
+                            tool_calls.append(AssistantPromptMessage.ToolCall(
+                                id=tool_call_id,
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool,
+                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                )
+                            ))
+                            tool_call_response.append(ToolPromptMessage(
+                                content=agent_thought.observation,
+                                name=tool,
+                                tool_call_id=tool_call_id,
+                            ))
+
+                        result.extend([
+                            AssistantPromptMessage(
+                                content=agent_thought.thought,
+                                tool_calls=tool_calls,
+                            ),
+                            *tool_call_response
+                        ])
+                    if not tools:
+                        result.append(AssistantPromptMessage(content=agent_thought.thought))
+            else:
+                if message.answer:
+                    result.append(AssistantPromptMessage(content=message.answer))

         return result
\ No newline at end of file
@@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 thought='',
                 action_str='',
                 observation='',
-                action=None
+                action=None,
             )

             # publish agent thought if it's first iteration
@@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                     thought=message.content,
                     action_str='',
                     action=None,
-                    observation=None
+                    observation=None,
                 )
                 if message.tool_calls:
                     try:
@@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 elif isinstance(message, ToolPromptMessage):
                     if current_scratchpad:
                         current_scratchpad.observation = message.content

         return agent_scratchpad
     def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
@@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 prompt_message.content = system_message
                 overridden = True
                 break

+        # convert tool prompt messages to user prompt messages
+        for idx, prompt_message in enumerate(prompt_messages):
+            if isinstance(prompt_message, ToolPromptMessage):
+                prompt_messages[idx] = UserPromptMessage(
+                    content=prompt_message.content
+                )
+
         if not overridden:
             prompt_messages.insert(0, SystemPromptMessage(
...
@@ -104,37 +104,17 @@ class HostingConfiguration:
         if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
             hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
+            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
             trial_quota = TrialHostingQuota(
                 quota_limit=hosted_quota_limit,
-                restrict_models=[
-                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
-                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
-                ]
+                restrict_models=trial_models
             )
             quotas.append(trial_quota)

         if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
+            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
             paid_quota = PaidHostingQuota(
-                restrict_models=[
-                    RestrictModel(model="gpt-4", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
-                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
-                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
-                ]
+                restrict_models=paid_models
             )
             quotas.append(paid_quota)
@@ -258,3 +238,11 @@ class HostingConfiguration:
         return HostedModerationConfig(
             enabled=False
         )
+
+    @staticmethod
+    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
+        models_str = app_config.get(env_var)
+        models_list = models_str.split(",") if models_str else []
+        return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
+                model_name.strip()]
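
An illustrative round trip for the new helper (the env value is made up):

# HOSTED_OPENAI_TRIAL_MODELS="gpt-3.5-turbo, gpt-3.5-turbo-16k,"
# parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
# -> [RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
#     RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM)]
# Whitespace is stripped and empty entries from trailing commas are dropped.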
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                model_type=ModelType.TEXT_EMBEDDING,
                provider=dataset.embedding_model_provider,
                model=dataset.embedding_model
            )

            embeddings = CacheEmbedding(embedding_model)

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')

    @classmethod
    def get_default_high_quality_index(cls, dataset: Dataset):
        embeddings = OpenAIEmbeddings(openai_api_key=' ')
        return VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings
        )
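
Usage sketch for the old builder (the `dataset` row is assumed):

index = IndexBuilder.get_index(dataset, indexing_technique='high_quality')
if index is None:
    # dataset is not 'high_quality' and the check was not ignored
    pass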
This diff is collapsed.
from typing import Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from models.dataset import Dataset


class MilvusConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    secure: bool = False
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['host']:
            raise ValueError("config MILVUS_HOST is required")
        if not values['port']:
            raise ValueError("config MILVUS_PORT is required")
        if not values['user']:
            raise ValueError("config MILVUS_USER is required")
        if not values['password']:
            raise ValueError("config MILVUS_PASSWORD is required")
        return values

    def to_milvus_params(self):
        return {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'secure': self.secure
        }


class MilvusVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'milvus'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        index_params = {
            'metric_type': 'IP',
            'index_type': "HNSW",
            'params': {"M": 8, "efConstruction": 64}
        }
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            connection_args=self._client_config.to_milvus_params(),
            index_params=index_params
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=collection_name,
            ids=uuids,
            content_payload_key='page_content'
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        return MilvusVectorStore(
            collection_name=self.get_index_name(self.dataset),
            embedding_function=self._embeddings,
            connection_args=self._client_config.to_milvus_params()
        )

    def _get_vector_store_class(self) -> type:
        return MilvusVectorStore

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_document_id(document_id)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_metadata_field(key, value)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_doc_ids(doc_ids)
        vector_store.del_texts({
            'filter': f' id in {ids}'
        })

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
            ],
        ))

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        # milvus/zilliz doesn't support bm25 search
        return []
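
Note that `delete()` above builds a qdrant_client Filter even though this is the Milvus index, which looks carried over from the Qdrant implementation; a Milvus-style sketch using the helpers this class already calls would be:

def delete(self) -> None:  # sketch, not the shipped code
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
    # reuse the id lookup that delete_by_metadata_field relies on
    ids = vector_store.get_ids_by_metadata_field('group_id', self.dataset.id)
    if ids:
        vector_store.del_texts({'filter': f'id in {ids}'})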
import os
from typing import Any, Optional, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
))
def delete_by_ids(self, ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
for node_id in ids:
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
return vector_store.similarity_search_by_bm25(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
],
), kwargs.get('top_k', 2))
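# Hypothetical usage sketch of the BM25-style full-text search above; the
# endpoint and query values are placeholders, and `dataset`/`embeddings` are
# assumed to be configured elsewhere.
#
# index = QdrantVectorIndex(
#     dataset=dataset,
#     config=QdrantConfig(
#         endpoint='http://localhost:6333',
#         api_key=None,
#         root_path=current_app.root_path,
#         timeout=20,
#     ),
#     embeddings=embeddings,
# )
# docs = index.search_by_full_text_index('refund policy', top_k=4)
# # Results are scoped to this dataset by the `group_id` payload filter and
# # matched against the `page_content` payload via MatchText.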
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
self._attributes = attributes
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
attributes: list) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings,
attributes=attributes
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
),
embeddings=embeddings
)
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
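# Hypothetical wiring sketch for the factory above: VECTOR_STORE selects the
# concrete index, add_texts creates it on first use, and __getattr__ proxies
# every other call to the underlying implementation.
#
# vector_index = VectorIndex(
#     dataset=dataset,
#     config=current_app.config,
#     embeddings=embeddings,
# )
# vector_index.add_texts(documents)            # creates index and stores index_struct
# vector_index.delete_by_document_id(doc_id)   # proxied via __getattr__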
from typing import Any, Optional, cast
import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = self._attributes
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def delete_by_metadata_field(self, key: str, value: str):
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": [key],
"valueText": value
})
def delete_by_group_id(self, group_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
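# Naming sketch for get_index_name above (the dataset id is illustrative): a
# dataset without a stored class_prefix gets a deterministic Weaviate class
# name, and legacy prefixes missing the '_Node' suffix are treated as "origin"
# indexes by _is_origin.
#
# dataset_id = '1c1f2e8c-5b3a-4f0e-9d2a-0242ac120002'
# "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
# # -> 'Vector_index_1c1f2e8c_5b3a_4f0e_9d2a_0242ac120002_Node'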
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./schema.md) ​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)
- 可选择的模型列表展示 - 可选择的模型列表展示
...@@ -86,4 +86,4 @@ Model Runtime 分三层: ...@@ -86,4 +86,4 @@ Model Runtime 分三层:
![Alt text](docs/zh_Hans/images/index/image-2.png) ![Alt text](docs/zh_Hans/images/index/image-2.png)
### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) ### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md)
你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。
\ No newline at end of file
...@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { ...@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
'min': 1, 'min': 1,
'max': 2048, 'max': 2048,
'precision': 0, 'precision': 0,
},
DefaultParameterName.RESPONSE_FORMAT: {
'label': {
'en_US': 'Response Format',
'zh_Hans': '回复格式',
},
'type': 'string',
'help': {
'en_US': 'Set a response format to ensure the output from the LLM is a valid code block, such as JSON or XML.',
'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
},
'required': False,
'options': ['JSON', 'XML'],
} }
} }
\ No newline at end of file
...@@ -91,6 +91,7 @@ class DefaultParameterName(Enum): ...@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
PRESENCE_PENALTY = "presence_penalty" PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty" FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens" MAX_TOKENS = "max_tokens"
RESPONSE_FORMAT = "response_format"
@classmethod @classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName': def value_of(cls, value: Any) -> 'DefaultParameterName':
......
...@@ -262,23 +262,23 @@ class AIModel(ABC): ...@@ -262,23 +262,23 @@ class AIModel(ABC):
try: try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max: if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max'] parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min: if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min'] parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.precision: if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default'] parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision: if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision'] parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required: if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required'] parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help: if not parameter_rule.help and 'help' in default_parameter_rule:
parameter_rule.help = I18nObject( parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'], en_US=default_parameter_rule['help']['en_US'],
) )
if not parameter_rule.help.en_US: if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans: if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
except ValueError: except ValueError:
pass pass
......
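# Sketch of the guarded defaulting introduced above: a field is copied from the
# default-parameter template only when the rule leaves it unset AND the template
# defines it. `rule` and `template` are illustrative stand-ins; the
# RESPONSE_FORMAT template, for example, defines no 'max', so that branch is
# skipped safely instead of raising a KeyError.
#
# template = {'type': 'string', 'required': False, 'options': ['JSON', 'XML']}
# if not rule.max and 'max' in template:        # skipped: template has no 'max'
#     rule.max = template['max']
# if not rule.required and 'required' in template:
#     rule.required = template['required']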
...@@ -27,6 +27,8 @@ parameter_rules: ...@@ -27,6 +27,8 @@ parameter_rules:
default: 4096 default: 4096
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
use_template: response_format
pricing: pricing:
input: '8.00' input: '8.00'
output: '24.00' output: '24.00'
......
...@@ -27,6 +27,8 @@ parameter_rules: ...@@ -27,6 +27,8 @@ parameter_rules:
default: 4096 default: 4096
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
use_template: response_format
pricing: pricing:
input: '8.00' input: '8.00'
output: '24.00' output: '24.00'
......
...@@ -26,6 +26,8 @@ parameter_rules: ...@@ -26,6 +26,8 @@ parameter_rules:
default: 4096 default: 4096
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
use_template: response_format
pricing: pricing:
input: '1.63' input: '1.63'
output: '5.51' output: '5.51'
......
...@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream ...@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params from anthropic.types import Completion, completion_create_params
from httpx import Timeout from httpx import Timeout
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import ( ...@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
...@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): ...@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
""" """
# invoke model # invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
self._transform_json_prompts(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(
content=f"```{response_format}\n"
))
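# Illustrative before/after of the transform above with response_format='JSON'
# (message contents abridged): the system prompt is wrapped in the block-mode
# template and the assistant is primed to open a fenced block, while '```\n'
# is registered as a stop sequence to end generation after the block.
#
# before: [SystemPromptMessage('You are a helpful bot.'),
#          UserPromptMessage('List two colors.')]
# after:  [SystemPromptMessage('You should always follow the instructions ... '
#                              '<instructions>You are a helpful bot.</instructions>'),
#          UserPromptMessage('List two colors.'),
#          AssistantPromptMessage('```JSON\n')]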
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
......
...@@ -27,6 +27,8 @@ parameter_rules: ...@@ -27,6 +27,8 @@ parameter_rules:
default: 2048 default: 2048
min: 1 min: 1
max: 2048 max: 2048
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.00' input: '0.00'
output: '0.00' output: '0.00'
......
...@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large ...@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class GoogleLargeLanguageModel(LargeLanguageModel): class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, def _invoke(self, model: str, credentials: dict,
...@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ...@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
""" """
# invoke model # invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
""" """
......
...@@ -24,6 +24,18 @@ parameter_rules: ...@@ -24,6 +24,18 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output
required: false
options:
- text
- json_object
pricing: pricing:
input: '0.0005' input: '0.0005'
output: '0.0015' output: '0.0015'
......
...@@ -24,6 +24,8 @@ parameter_rules: ...@@ -24,6 +24,8 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.0015' input: '0.0015'
output: '0.002' output: '0.002'
......
...@@ -24,6 +24,18 @@ parameter_rules: ...@@ -24,6 +24,18 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output
required: false
options:
- text
- json_object
pricing: pricing:
input: '0.001' input: '0.001'
output: '0.002' output: '0.002'
......
...@@ -24,6 +24,8 @@ parameter_rules: ...@@ -24,6 +24,8 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 16385 max: 16385
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.003' input: '0.003'
output: '0.004' output: '0.004'
......
...@@ -24,6 +24,8 @@ parameter_rules: ...@@ -24,6 +24,8 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 16385 max: 16385
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.003' input: '0.003'
output: '0.004' output: '0.004'
......
...@@ -21,6 +21,8 @@ parameter_rules: ...@@ -21,6 +21,8 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.0015' input: '0.0015'
output: '0.002' output: '0.002'
......
...@@ -24,6 +24,18 @@ parameter_rules: ...@@ -24,6 +24,18 @@ parameter_rules:
default: 512 default: 512
min: 1 min: 1
max: 4096 max: 4096
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: Specify the format that the model must output
required: false
options:
- text
- json_object
pricing: pricing:
input: '0.001' input: '0.001'
output: '0.002' output: '0.002'
......
...@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio ...@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall from openai.types.chat.chat_completion_message import FunctionCall
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI ...@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
""" """
...@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): ...@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
user=user user=user
) )
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
# handle fine tune remote models
base_model = model
if model.startswith('ft:'):
base_model = model.split(':')[1]
# get model mode
model_mode = self.get_model_mode(base_model, credentials)
# transform response format
if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
stop = stop or []
if model_mode == LLMMode.CHAT:
# chat model
self._transform_chat_json_prompts(
model=base_model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
else:
self._transform_completion_json_prompts(
model=base_model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def _transform_completion_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# override the last user message
user_message = None
for i in range(len(prompt_messages) - 1, -1, -1):
if isinstance(prompt_messages[i], UserPromptMessage):
user_message = prompt_messages[i]
break
if user_message:
if prompt_messages[i].content[-11:] == 'Assistant: ':
# now we are in the chat app, remove the last assistant message
prompt_messages[i].content = prompt_messages[i].content[:-11]
prompt_messages[i] = UserPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", user_message.content)
.replace("{{block}}", response_format)
)
prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
else:
prompt_messages[i] = UserPromptMessage(
content=OPENAI_BLOCK_MODE_PROMPT
.replace("{{instructions}}", user_message.content)
.replace("{{block}}", response_format)
)
prompt_messages[i].content += f"\n```{response_format}\n"
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
""" """
......
...@@ -13,6 +13,7 @@ from dashscope.common.error import ( ...@@ -13,6 +13,7 @@ from dashscope.common.error import (
) )
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel): ...@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
""" """
# invoke model # invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _code_block_mode_wrapper(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
-> LLMResult | Generator:
"""
Wrapper for code block mode
"""
block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
code_block = model_parameters.get("response_format", "")
if not code_block:
return self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
model_parameters.pop("response_format")
stop = stop or []
stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block)
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=block_prompts
.replace("{{instructions}}", f"Please output a valid {code_block} object.")
))
mode = self.get_model_mode(model, credentials)
if mode == LLMMode.CHAT:
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# prime the last user message to open the code block fence
prompt_messages[-1].content += f"\n```{code_block}\n"
else:
# append a user message
prompt_messages.append(UserPromptMessage(
content=f"```{code_block}\n"
))
else:
prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
response = self._invoke(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
)
if isinstance(response, Generator):
return self._code_block_mode_stream_processor_with_backtick(
model=model,
prompt_messages=prompt_messages,
input_generator=response
)
return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
...@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): ...@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
""" """
extra_model_kwargs = {} extra_model_kwargs = {}
if stop: if stop:
extra_model_kwargs['stop_sequences'] = stop extra_model_kwargs['stop'] = stop
# transform credentials to kwargs for model instance # transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
...@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel): ...@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
params = { params = {
'model': model, 'model': model,
**model_parameters, **model_parameters,
**credentials_kwargs **credentials_kwargs,
**extra_model_kwargs,
} }
mode = self.get_model_mode(model, credentials) mode = self.get_model_mode(model, credentials)
......
...@@ -57,3 +57,5 @@ parameter_rules: ...@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false required: false
- name: response_format
use_template: response_format
...@@ -57,3 +57,5 @@ parameter_rules: ...@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false required: false
- name: response_format
use_template: response_format
...@@ -57,3 +57,5 @@ parameter_rules: ...@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false required: false
- name: response_format
use_template: response_format
...@@ -56,6 +56,8 @@ parameter_rules: ...@@ -56,6 +56,8 @@ parameter_rules:
help: help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.02' input: '0.02'
output: '0.02' output: '0.02'
......
...@@ -57,6 +57,8 @@ parameter_rules: ...@@ -57,6 +57,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false required: false
- name: response_format
use_template: response_format
pricing: pricing:
input: '0.008' input: '0.008'
output: '0.008' output: '0.008'
......
...@@ -25,6 +25,8 @@ parameter_rules: ...@@ -25,6 +25,8 @@ parameter_rules:
use_template: presence_penalty use_template: presence_penalty
- name: frequency_penalty - name: frequency_penalty
use_template: frequency_penalty use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search - name: disable_search
label: label:
zh_Hans: 禁用搜索 zh_Hans: 禁用搜索
......
...@@ -25,6 +25,8 @@ parameter_rules: ...@@ -25,6 +25,8 @@ parameter_rules:
use_template: presence_penalty use_template: presence_penalty
- name: frequency_penalty - name: frequency_penalty
use_template: frequency_penalty use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search - name: disable_search
label: label:
zh_Hans: 禁用搜索 zh_Hans: 禁用搜索
......
...@@ -25,3 +25,5 @@ parameter_rules: ...@@ -25,3 +25,5 @@ parameter_rules:
use_template: presence_penalty use_template: presence_penalty
- name: frequency_penalty - name: frequency_penalty
use_template: frequency_penalty use_template: frequency_penalty
- name: response_format
use_template: response_format
...@@ -34,3 +34,5 @@ parameter_rules: ...@@ -34,3 +34,5 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。 zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search. en_US: Disable the model to perform external search.
required: false required: false
- name: response_format
use_template: response_format
from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import Optional, Union, cast
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
...@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( ...@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
RateLimitReachedError, RateLimitReachedError,
) )
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
You should also complete the text that starts with ``` but do not output ``` directly.
"""
class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
...@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel): ...@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel):
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
response_format = model_parameters['response_format']
stop = stop or []
self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
model_parameters.pop('response_format')
if stream:
return self._code_block_mode_stream_processor(
model=model,
prompt_messages=prompt_messages,
input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
)
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts to model prompts
"""
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ERNIE_BOT_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ERNIE_BOT_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# prime the last user message to open the code block fence
prompt_messages[-1].content += f"\n```{response_format}\n"
else:
# append a user message that opens the code block fence
prompt_messages.append(UserPromptMessage(
content=f"```{response_format}\n"
))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int: tools: list[PromptMessageTool] | None = None) -> int:
# tools is not supported yet # tools is not supported yet
......
...@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp ...@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper from core.model_runtime.utils import helper
GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
And you should always end the block with a "```" to indicate the end of the JSON object.
<instructions>
{{instructions}}
</instructions>
```JSON"""
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
...@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ...@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)
# invoke model # invoke model
# stop = stop or []
# self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
# def _transform_json_prompts(self, model: str, credentials: dict,
# prompt_messages: list[PromptMessage], model_parameters: dict,
# tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
# stream: bool = True, user: str | None = None) \
# -> None:
# """
# Transform json prompts to model prompts
# """
# if "}\n\n" not in stop:
# stop.append("}\n\n")
# # check if there is a system message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# # override the system message
# prompt_messages[0] = SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
# )
# else:
# # insert the system message
# prompt_messages.insert(0, SystemPromptMessage(
# content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
# ))
# # check if the last message is a user message
# if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
# # add ```JSON\n to the last message
# prompt_messages[-1].content += "\n```JSON\n"
# else:
# # append a user message
# prompt_messages.append(UserPromptMessage(
# content="```JSON\n"
# ))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:
""" """
...@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ...@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
""" """
extra_model_kwargs = {} extra_model_kwargs = {}
if stop: if stop:
extra_model_kwargs['stop_sequences'] = stop extra_model_kwargs['stop'] = stop
client = ZhipuAI( client = ZhipuAI(
api_key=credentials_kwargs['api_key'] api_key=credentials_kwargs['api_key']
...@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ...@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
] ]
if stream: if stream:
response = client.chat.completions.create(stream=stream, **params) response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages) return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
response = client.chat.completions.create(**params) response = client.chat.completions.create(**params, **extra_model_kwargs)
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
def _handle_generate_response(self, model: str, def _handle_generate_response(self, model: str,
......
import re
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
# default clean
# remove invalid symbol
text = re.sub(r'<\|', '<', text)
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
# Unicode U+FFFE
text = re.sub('\uFFFE', '', text)
rules = process_rule['rules'] if process_rule else {}
if 'pre_processing_rules' in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r'\n{3,}'
text = re.sub(pattern, '\n\n', text)
pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
text = re.sub(pattern, ' ', text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
text = re.sub(pattern, '', text)
# Remove URL
pattern = r'https?://[^\s]+'
text = re.sub(pattern, '', text)
return text
def filter_string(self, text):
return text
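# Hypothetical call showing the expected shape of `process_rule`; the rule ids
# mirror the ones handled above, and the sample text is illustrative only.
if __name__ == '__main__':
    sample_rule = {
        'rules': {
            'pre_processing_rules': [
                {'id': 'remove_extra_spaces', 'enabled': True},
                {'id': 'remove_urls_emails', 'enabled': True},
            ]
        }
    }
    cleaned = CleanProcessor.clean('Contact  me   at a@b.com\n\n\n\nThanks', sample_rule)
    print(repr(cleaned))  # 'Contact me at \n\nThanks'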
"""Abstract interface for document cleaner implementations."""
from abc import ABC, abstractmethod
class BaseCleaner(ABC):
"""Interface for clean chunk content.
"""
@abstractmethod
def clean(self, content: str):
raise NotImplementedError
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredExtraWhitespaceCleaner(BaseCleaner):
def clean(self, content) -> str:
"""Clean extra whitespace from document content."""
from unstructured.cleaners.core import clean_extra_whitespace
# e.g. "ITEM 1A:     RISK FACTORS\n" -> "ITEM 1A: RISK FACTORS"
return clean_extra_whitespace(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
import re
from unstructured.cleaners.core import group_broken_paragraphs
para_split_re = re.compile(r"(\s*\n\s*){3}")
return group_broken_paragraphs(content, paragraph_split=para_split_re)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.core import clean_non_ascii_chars
# Returns "This text containsnon-ascii characters!"
return clean_non_ascii_chars(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredReplaceUnicodeQuotesCleaner(BaseCleaner):
def clean(self, content) -> str:
"""Replaces unicode quote characters, such as the \x91 character in a string."""
from unstructured.cleaners.core import replace_unicode_quotes
return replace_unicode_quotes(content)
"""Abstract interface for document clean implementations."""
from core.rag.cleaner.cleaner_base import BaseCleaner
class UnstructuredTranslateTextCleaner(BaseCleaner):
def clean(self, content) -> str:
"""clean document content."""
from unstructured.cleaners.translate import translate_text
return translate_text(content)
from typing import Optional
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.models.document import Document
from core.rerank.rerank import RerankRunner
class DataPostProcessor:
"""Interface for data post-processing document.
"""
def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
self.reorder_runner = self._get_reorder_runner(reorder_enabled)
def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)
return documents
def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
if reranking_model:
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model['reranking_provider_name'],
model_type=ModelType.RERANK,
model=reranking_model['reranking_model_name']
)
except InvokeAuthorizationError:
return None
return RerankRunner(rerank_model_instance)
return None
def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:
if reorder_enabled:
return ReorderRunner()
return None
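# Hypothetical wiring of the post-processor above; the tenant id and reranking
# model names are placeholders, not values from this diff.
#
# post_processor = DataPostProcessor(
#     tenant_id='tenant-123',
#     reranking_model={
#         'reranking_provider_name': 'cohere',
#         'reranking_model_name': 'rerank-english-v2.0',
#     },
#     reorder_enabled=True,
# )
# documents = post_processor.invoke(query, documents, score_threshold=0.5, top_n=5)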
from core.rag.models.document import Document
class ReorderRunner:
def run(self, documents: list[Document]) -> list[Document]:
# Retrieve elements at even indices (0, 2, 4, etc.) of the documents list
even_elements = documents[::2]
# Retrieve elements at odd indices (1, 3, 5, etc.) of the documents list
odd_elements = documents[1::2]
# Reverse the list of elements from odd indices
odd_elements_reversed = odd_elements[::-1]
new_documents = even_elements + odd_elements_reversed
return new_documents
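# Worked example of the reorder above: the top-ranked documents end up at both
# ends of the list, pushing the least relevant ones to the middle (a "lost in
# the middle" style layout).
if __name__ == '__main__':
    docs = ['d1', 'd2', 'd3', 'd4', 'd5']  # stand-ins for ranked Document objects
    assert docs[::2] + docs[1::2][::-1] == ['d1', 'd3', 'd5', 'd4', 'd2']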
from abc import ABC, abstractmethod
class Embeddings(ABC):
"""Interface for embedding models."""
@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""
@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Asynchronous Embed search docs."""
raise NotImplementedError
async def aembed_query(self, text: str) -> list[float]:
"""Asynchronous Embed query text."""
raise NotImplementedError
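# Minimal concrete sketch of the interface above, handy for tests; the hashed
# pseudo-embedding is illustrative only, not a real model.
import hashlib


class FakeEmbeddings(Embeddings):
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        # Derive eight stable pseudo-dimensions from a SHA-256 digest.
        digest = hashlib.sha256(text.encode('utf-8')).digest()
        return [byte / 255.0 for byte in digest[:8]]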
@@ -2,11 +2,11 @@ import json
from collections import defaultdict
from typing import Any, Optional

-from langchain.schema import BaseRetriever, Document
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel

-from core.index.base import BaseIndex
-from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.datasource.keyword.keyword_base import BaseKeyword
+from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment

@@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel):
    max_keywords_per_chunk: int = 10


-class KeywordTableIndex(BaseIndex):
-    def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
+class Jieba(BaseKeyword):
+    def __init__(self, dataset: Dataset):
        super().__init__(dataset)
-        self._config = config
+        self._config = KeywordTableConfig()

-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = {}
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
-
-        dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self.dataset.id,
-            keyword_table=json.dumps({
-                '__type__': 'keyword_table',
-                '__data__': {
-                    "index_id": self.dataset.id,
-                    "summary": None,
-                    "table": {}
-                }
-            }, cls=SetEncoder)
-        )
-        db.session.add(dataset_keyword_table)
-        db.session.commit()
-
-        self._save_dataset_keyword_table(keyword_table)
-
-        return self
-
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
-        keyword_table_handler = JiebaKeywordTableHandler()
-        keyword_table = {}
+        keyword_table = self._get_dataset_keyword_table()
        for text in texts:
            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

-        dataset_keyword_table = DatasetKeywordTable(
-            dataset_id=self.dataset.id,
-            keyword_table=json.dumps({
-                '__type__': 'keyword_table',
-                '__data__': {
-                    "index_id": self.dataset.id,
-                    "summary": None,
-                    "table": {}
-                }
-            }, cls=SetEncoder)
-        )
-        db.session.add(dataset_keyword_table)
-        db.session.commit()
-
        self._save_dataset_keyword_table(keyword_table)

        return self

@@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex):
        keyword_table_handler = JiebaKeywordTableHandler()
        keyword_table = self._get_dataset_keyword_table()
-        for text in texts:
-            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+        keywords_list = kwargs.get('keywords_list', None)
+        for i in range(len(texts)):
+            text = texts[i]
+            if keywords_list:
+                keywords = keywords_list[i]
+            else:
+                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

@@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex):
        self._save_dataset_keyword_table(keyword_table)

-    def delete_by_metadata_field(self, key: str, value: str):
-        pass
-
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
-        return KeywordTableRetriever(index=self, **kwargs)
-
    def search(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        keyword_table = self._get_dataset_keyword_table()

-        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
-        k = search_kwargs.get('k') if search_kwargs.get('k') else 4
+        k = kwargs.get('top_k', 4)

        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

@@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex):
            db.session.delete(dataset_keyword_table)
            db.session.commit()

-    def delete_by_group_id(self, group_id: str) -> None:
-        dataset_keyword_table = self.dataset.dataset_keyword_table
-        if dataset_keyword_table:
-            db.session.delete(dataset_keyword_table)
-            db.session.commit()
-
    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
            '__type__': 'keyword_table',

@@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex):
        ).first()
        if document_segment:
            document_segment.keywords = keywords
+            db.session.add(document_segment)
            db.session.commit()

    def create_segment_keywords(self, node_id: str, keywords: list[str]):

@@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex):
        self._save_dataset_keyword_table(keyword_table)


-class KeywordTableRetriever(BaseRetriever, BaseModel):
-    index: KeywordTableIndex
-    search_kwargs: dict = Field(default_factory=dict)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    def get_relevant_documents(self, query: str) -> list[Document]:
-        """Get documents relevant for a query.
-
-        Args:
-            query: string to find relevant documents for
-
-        Returns:
-            List of relevant documents
-        """
-        return self.index.search(query, **self.search_kwargs)
-
-    async def aget_relevant_documents(self, query: str) -> list[Document]:
-        raise NotImplementedError("KeywordTableRetriever does not support async")
-
class SetEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
...
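The reworked add_texts accepts an optional keywords_list kwarg so callers can supply precomputed keywords and skip jieba extraction. A usage sketch, assuming a Dataset row already loaded as `dataset` (the document payloads here are made up):

# Sketch only: `dataset` is assumed to be a loaded models.dataset.Dataset row.
from core.rag.models.document import Document

docs = [
    Document(page_content="Dify supports keyword retrieval", metadata={'doc_id': 'seg-1'}),
    Document(page_content="jieba extracts keywords from Chinese text", metadata={'doc_id': 'seg-2'}),
]

keyword_index = Jieba(dataset=dataset)
keyword_index.add_texts(docs)                  # keywords extracted by jieba
keyword_index.add_texts(docs, keywords_list=[  # or pass precomputed keywords
    ['dify', 'keyword', 'retrieval'],
    ['jieba', 'chinese', 'keywords'],
])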
@@ -3,7 +3,7 @@ import re
import jieba
from jieba.analyse import default_tfidf

-from core.index.keyword_table_index.stopwords import STOPWORDS
+from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
...
@@ -3,22 +3,17 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any

-from langchain.schema import BaseRetriever, Document
+from core.rag.models.document import Document
from models.dataset import Dataset


-class BaseIndex(ABC):
+class BaseKeyword(ABC):
    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    @abstractmethod
-    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
-        raise NotImplementedError
-
-    @abstractmethod
-    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+    def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
        raise NotImplementedError

    @abstractmethod
@@ -34,31 +29,18 @@ class BaseIndex(ABC):
        raise NotImplementedError

    @abstractmethod
-    def delete_by_metadata_field(self, key: str, value: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def delete_by_group_id(self, group_id: str) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def delete_by_document_id(self, document_id: str):
+    def delete_by_document_id(self, document_id: str) -> None:
        raise NotImplementedError

-    @abstractmethod
-    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
+    def delete(self) -> None:
        raise NotImplementedError

-    @abstractmethod
    def search(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        raise NotImplementedError

-    def delete(self) -> None:
-        raise NotImplementedError
-
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        for text in texts:
            doc_id = text.metadata['doc_id']
...
from typing import Any
from flask import current_app
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from models.dataset import Dataset
class Keyword:
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
config = current_app.config
keyword_type = config.get('KEYWORD_STORE')
if not keyword_type:
raise ValueError("Keyword store must be specified.")
if keyword_type == "jieba":
return Jieba(
dataset=self._dataset
)
else:
raise ValueError(f"Keyword store {keyword_type} is not supported.")
def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs)
def add_texts(self, texts: list[Document], **kwargs):
self._keyword_processor.add_texts(texts, **kwargs)
def text_exists(self, id: str) -> bool:
return self._keyword_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._keyword_processor.delete_by_ids(ids)
def delete_by_document_id(self, document_id: str) -> None:
self._keyword_processor.delete_by_document_id(document_id)
def delete(self) -> None:
self._keyword_processor.delete()
def search(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._keyword_processor.search(query, **kwargs)
def __getattr__(self, name):
if self._keyword_processor is not None:
method = getattr(self._keyword_processor, name)
if callable(method):
return method
raise AttributeError(f"'Keyword' object has no attribute '{name}'")
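`Keyword` is a small facade: `_init_keyword` reads `KEYWORD_STORE` from the Flask config and instantiates the matching backend (currently only `jieba`), and every public method simply delegates to it. A usage sketch, assuming an active Flask app context and a loaded `dataset`:

# Sketch only: assumes app.config['KEYWORD_STORE'] == 'jieba' and that
# `dataset` is an existing models.dataset.Dataset row.
from core.rag.models.document import Document

docs = [Document(page_content="Dify supports keyword retrieval", metadata={'doc_id': 'seg-1'})]

keyword = Keyword(dataset=dataset)
keyword.create(docs)                                # build the keyword table
for doc in keyword.search("keyword retrieval", top_k=4):
    print(doc.metadata['doc_id'], doc.page_content)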
+import threading
from typing import Optional

from flask import Flask, current_app

-from langchain.embeddings.base import Embeddings
-
-from core.index.vector_index.vector_index import VectorIndex
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.rerank.rerank import RerankRunner
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset

@@ -25,48 +23,115 @@ default_retrieval_model = {
class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = 0.0, reranking_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
# keyword search branch
if retrival_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents
})
threads.append(keyword_thread)
keyword_thread.start()
# semantic (embedding) search branch
if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrival_method': retrival_method
})
threads.append(embedding_thread)
embedding_thread.start()
# full-text search branch
if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrival_method': retrival_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
)
return all_documents
@classmethod
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
keyword = Keyword(
dataset=dataset
)
documents = keyword.search(
query,
top_k=top_k
)
all_documents.extend(documents)
    @classmethod
    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                        all_documents: list, search_method: str, embeddings: Embeddings):
+                        all_documents: list, retrival_method: str):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
+            vector = Vector(
+                dataset=dataset
            )

-            documents = vector_index.search(
+            documents = vector.search_by_vector(
                query,
                search_type='similarity_score_threshold',
-                search_kwargs={
-                    'k': top_k,
-                    'score_threshold': score_threshold,
-                    'filter': {
-                        'group_id': [dataset.id]
-                    }
+                top_k=top_k,
+                score_threshold=score_threshold,
+                filter={
+                    'group_id': [dataset.id]
                }
            )

            if documents:
-                if reranking_model and search_method == 'semantic_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
+                if reranking_model and retrival_method == 'semantic_search':
+                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                    all_documents.extend(data_post_processor.invoke(
                        query=query,
                        documents=documents,
                        score_threshold=score_threshold,

@@ -78,38 +143,24 @@ class RetrievalService:
    @classmethod
    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                              all_documents: list, search_method: str, embeddings: Embeddings):
+                              all_documents: list, retrival_method: str):
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

-            vector_index = VectorIndex(
+            vector_processor = Vector(
                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
            )

-            documents = vector_index.search_by_full_text_index(
+            documents = vector_processor.search_by_full_text(
                query,
-                search_type='similarity_score_threshold',
                top_k=top_k
            )

            if documents:
-                if reranking_model and search_method == 'full_text_search':
-                    try:
-                        model_manager = ModelManager()
-                        rerank_model_instance = model_manager.get_model_instance(
-                            tenant_id=dataset.tenant_id,
-                            provider=reranking_model['reranking_provider_name'],
-                            model_type=ModelType.RERANK,
-                            model=reranking_model['reranking_model_name']
-                        )
-                    except InvokeAuthorizationError:
-                        return
-                    rerank_runner = RerankRunner(rerank_model_instance)
-                    all_documents.extend(rerank_runner.run(
+                if reranking_model and retrival_method == 'full_text_search':
+                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                    all_documents.extend(data_post_processor.invoke(
                        query=query,
                        documents=documents,
                        score_threshold=score_threshold,
...
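Putting the service together: `retrieve` fans out to keyword, embedding, and full-text threads depending on `retrival_method` (the misspelling is the codebase's own), joins them, and for hybrid search reranks the merged list through `DataPostProcessor`. A hedged call sketch, assuming a Flask app context and an existing dataset:

# Sketch only: `dataset` is assumed loaded; the reranking provider and
# model names below are placeholders, not required values.
results = RetrievalService.retrieve(
    retrival_method='hybrid_search',
    dataset_id=dataset.id,
    query="how is the vector store configured?",
    top_k=4,
    score_threshold=0.0,
    reranking_model={
        'reranking_provider_name': 'cohere',            # placeholder
        'reranking_model_name': 'rerank-english-v2.0',  # placeholder
    },
)
for doc in results:
    print(doc.metadata.get('score'), doc.page_content[:80])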
from enum import Enum
class Field(Enum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
import logging
from typing import Any, Optional
from uuid import uuid4
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._fields = []
def get_type(self) -> str:
return 'milvus'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
metadatas = [d.metadata for d in texts]
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
from pymilvus import utility
utility.drop_collection(self._collection_name, None, using=alias)
def text_exists(self, id: str) -> bool:
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result['entity'].get(Field.METADATA_KEY.value)
metadata['score'] = result['distance']
score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
) -> str:
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
return collection_name
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password)
return client
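Wiring `MilvusVector` by hand looks like the sketch below. In Dify the config normally comes from environment variables via the vector factory, so the host, port, and credentials here are placeholders, and a running Milvus server is assumed:

# Sketch only: placeholder connection details, toy 4-dimensional vectors.
from core.rag.models.document import Document

config = MilvusConfig(
    host="127.0.0.1",   # placeholder host
    port=19530,         # default Milvus port
    user="root",        # placeholder credentials
    password="Milvus",
)
store = MilvusVector(collection_name="demo_collection", config=config)

docs = [Document(page_content="hello milvus", metadata={'doc_id': 'seg-1', 'document_id': 'doc-1'})]
embeddings = [[0.1, 0.2, 0.3, 0.4]]
store.create(docs, embeddings)  # creates the collection if missing, then inserts
hits = store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=4, score_threshold=0.0)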
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
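Any new backend only has to fill in these abstract methods. As a shape check, here is a toy in-memory implementation — illustrative only, and it assumes numpy is available — using cosine similarity for search_by_vector and naive substring matching for search_by_full_text:

import numpy as np  # assumption: numpy is available


class InMemoryVector(BaseVector):
    """Toy backend keeping everything in a dict; not for production."""

    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._store: dict[str, tuple[Document, list[float]]] = {}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        for doc, emb in zip(documents, embeddings):
            self._store[doc.metadata['doc_id']] = (doc, emb)

    def text_exists(self, id: str) -> bool:
        return id in self._store

    def delete_by_ids(self, ids: list[str]) -> None:
        for doc_id in ids:
            self._store.pop(doc_id, None)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._store = {doc_id: pair for doc_id, pair in self._store.items()
                       if pair[0].metadata.get(key) != value}

    def search_by_vector(self, query_vector: list[float], **kwargs) -> list[Document]:
        q = np.asarray(query_vector)

        def cosine(emb: list[float]) -> float:
            v = np.asarray(emb)
            return float(q @ v / (np.linalg.norm(q) * np.linalg.norm(v) + 1e-9))

        # Rank all stored documents by similarity, keep the top_k best.
        ranked = sorted(self._store.values(), key=lambda pair: cosine(pair[1]), reverse=True)
        return [doc for doc, _ in ranked[:kwargs.get('top_k', 4)]]

    def search_by_full_text(self, query: str, **kwargs) -> list[Document]:
        return [doc for doc, _ in self._store.values() if query in doc.page_content]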
"""Abstract interface for document loader implementations."""
import csv import csv
import logging
from typing import Optional from typing import Optional
from langchain.document_loaders import CSVLoader as LCCSVLoader from core.rag.extractor.extractor_base import BaseExtractor
from langchain.document_loaders.helpers import detect_file_encodings from core.rag.models.document import Document
from langchain.schema import Document
logger = logging.getLogger(__name__)
class CSVExtractor(BaseExtractor):
"""Load CSV files.
Args:
file_path: Path to the file to load.
"""
class CSVLoader(LCCSVLoader):
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
source_column: Optional[str] = None, source_column: Optional[str] = None,
csv_args: Optional[dict] = None, csv_args: Optional[dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
): ):
self.file_path = file_path """Initialize with file path."""
self._file_path = file_path
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
self.source_column = source_column self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {} self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> list[Document]: def extract(self) -> list[Document]:
"""Load data into document objects.""" """Load data into document objects."""
try: try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile: with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
docs = self._read_from_file(csvfile) docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
if self.autodetect_encoding: if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path) detected_encodings = detect_filze_encodings(self._file_path)
for encoding in detected_encodings: for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try: try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile) docs = self._read_from_file(csvfile)
break break
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
else: else:
raise RuntimeError(f"Error loading {self.file_path}") from e raise RuntimeError(f"Error loading {self._file_path}") from e
return docs return docs
def _read_from_file(self, csvfile): def _read_from_file(self, csvfile) -> list[Document]:
docs = [] docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader): for i, row in enumerate(csv_reader):
......
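Calling the extractor is a two-liner; a sketch with a placeholder path (the `_read_from_file` body is truncated above, so the exact metadata layout of the returned documents isn't shown here):

# Sketch only: the path is a placeholder.
extractor = CSVExtractor(file_path="/tmp/example.csv", autodetect_encoding=True)
for doc in extractor.extract():
    print(doc.page_content)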
from enum import Enum
class DatasourceType(Enum):
FILE = "upload_file"
NOTION = "notion_import"
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
class BaseExtractor(ABC):
    """Interface for extracting files."""
@abstractmethod
def extract(self):
raise NotImplementedError
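Concrete extractors like `CSVExtractor` above subclass this interface; a minimal additional example — a plain-text extractor, illustrative only — looks like this:

# Illustrative only: a minimal BaseExtractor subclass.
from core.rag.models.document import Document


class TextExtractor(BaseExtractor):
    """Load one plain-text file into a single Document (sketch)."""

    def __init__(self, file_path: str, encoding: str = "utf-8"):
        self._file_path = file_path
        self._encoding = encoding

    def extract(self) -> list[Document]:
        with open(self._file_path, encoding=self._encoding) as f:
            return [Document(page_content=f.read(), metadata={'source': self._file_path})]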