Unverified commit 3241e401 authored by John Wang, committed by GitHub

feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
parent 1dee5de9
@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -79,7 +79,6 @@ def initialize_extensions(app):
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)
-    ext_vector_store.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
     ext_session.init_app(app)
...
 import datetime
+import logging
 import random
 import string
 import click
 from flask import current_app
+from werkzeug.exceptions import NotFound
+from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
+from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64
@@ -159,8 +163,39 @@ def generate_upper_string():
     return result


+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))


 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)
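The new recreate-all-dataset-indexes command above pages through datasets by calling Flask-SQLAlchemy's paginate() in a loop and treating the NotFound it raises for an out-of-range page as the stop signal (error_out is left at its default of True). A minimal sketch of that loop pattern, using a hypothetical Item model in place of Dataset:

# Sketch of the paginate-until-NotFound loop; Item is a hypothetical model, db a Flask-SQLAlchemy handle.
from werkzeug.exceptions import NotFound

def iterate_all_items(db, Item, per_page=50):
    page = 1
    while True:
        try:
            # With error_out=True (the default), asking for a page past the end raises NotFound (HTTP 404).
            pagination = db.session.query(Item) \
                .order_by(Item.created_at.desc()) \
                .paginate(page=page, per_page=per_page)
        except NotFound:
            break

        page += 1
        for item in pagination.items:
            yield item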
@@ -187,11 +187,13 @@ class Config:
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')


 class CloudEditionConfig(Config):
...
@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
-from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
         ).first()

         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()

         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200

     @setup_required
...
@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()

-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
+        text = FileExtractor.load(upload_file, return_text=True)

         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''

         return {'content': text}
...
@@ -32,8 +32,13 @@ class VersionApi(Resource):
                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }

         content = json.loads(response.content)
         return {
...
@@ -3,19 +3,11 @@ from typing import Optional

 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery


 class HostedOpenAICredential(BaseModel):
@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    default_tfidf.stop_words = STOPWORDS

     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())

     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
@@ -2,7 +2,7 @@ from typing import Optional

 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
@@ -16,23 +16,20 @@ class AgentBuilder:
     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )

-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]

         prompt = cls.build_agent_prompt_template(
             tools=tools,
@@ -54,7 +51,7 @@ class AgentBuilder:
             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
...
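The agent builder changes above follow the callback API migration in newer langchain releases: instead of wrapping handlers in a CallbackManager and assigning it to callback_manager, constructors and runtime objects accept a plain list of handlers through callbacks. A minimal sketch of the two styles, assuming an OpenAI key is configured in the environment; PrintHandler is an illustrative handler, not one from this codebase:

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import OpenAI

class PrintHandler(BaseCallbackHandler):
    """Illustrative handler that prints when an LLM call starts."""
    def on_llm_start(self, serialized, prompts, **kwargs):
        print("[llm start]", prompts[0][:80])

# Old style (left-hand side of the diff):
#   llm = OpenAI(callback_manager=CallbackManager([PrintHandler()]))
# New style (right-hand side): pass the handlers directly.
llm = OpenAI(temperature=0, callbacks=[PrintHandler()])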
@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask

 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.completion = response.generations[0][0].text
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None

-    def on_chain_start(
-            self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
     def on_tool_start(
             self,
             serialized: Dict[str, Any],
@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._agent_loops = []
         self._current_loop = None

-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         # Final Answer
...
@@ -3,7 +3,6 @@ import logging
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask
@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask

 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(
-            self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_start(
-            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_agent_action(
-            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document

 from extensions.ext_database import db
 from models.dataset import DocumentSegment


-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id

-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']

             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False
...
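The reworked handler above now receives plain langchain Document objects and reads the segment id from document.metadata['doc_id'] before bumping the counter. The update is a query-level UPDATE with synchronize_session=False, so a single SQL statement is issued per document and already-loaded objects are not reconciled. A small self-contained sketch of the same pattern, reusing the models referenced in the diff (the commit is shown only so the sketch stands alone; the handler may leave committing to its caller):

# Sketch: increment hit_count for one retrieved segment with a single UPDATE statement.
from extensions.ext_database import db
from models.dataset import DocumentSegment

def bump_hit_count(dataset_id: str, index_node_id: str) -> None:
    db.session.query(DocumentSegment).filter(
        DocumentSegment.dataset_id == dataset_id,
        DocumentSegment.index_node_id == index_node_id,
    ).update(
        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
        synchronize_session=False,  # skip syncing objects already loaded in the session
    )
    db.session.commit()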
@@ -3,7 +3,7 @@ import time
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage

 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI

 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True

     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):
@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Whether to call verbose callbacks even if verbose is False."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         self.start_at = time.perf_counter()

-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
-
-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]
+
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
-
-    def on_chain_start(
-            self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_tool_start(
-            self,
-            serialized: Dict[str, Any],
-            input_str: str,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        pass
-
-    def on_agent_finish(
-            self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        pass
 import logging
 import time
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask

 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
         self.clear_chain_results()
\ No newline at end of file
-
-    def on_llm_start(
-            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_tool_start(
-            self,
-            serialized: Dict[str, Any],
-            input_str: str,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-            self,
-            output: str,
-            color: Optional[str] = None,
-            observation_prefix: Optional[str] = None,
-            llm_prefix: Optional[str] = None,
-            **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(
-            self,
-            text: str,
-            color: Optional[str] = None,
-            end: str = "",
-            **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
+import os
 import sys
 from typing import Any, Dict, List, Optional, Union

 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage


 class DifyStdOutCallbackHandler(BaseCallbackHandler):
@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
         self.color = color

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""
@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
             self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")

+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+

 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""
...
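The four ignore_* properties added above use a hook that langchain's BaseCallbackHandler exposes: when ignore_llm, ignore_chain, ignore_agent, or ignore_chat_model returns True, the callback machinery skips dispatching that family of events to the handler. Here they all key off the DEBUG environment variable, so the stdout handler only prints when DEBUG=true. A compact sketch of the same idea with an illustrative handler:

import os
from langchain.callbacks.base import BaseCallbackHandler

def _debug_enabled() -> bool:
    # Same check as in the diff: only DEBUG=true (case-insensitive) counts as enabled.
    value = os.environ.get("DEBUG")
    return bool(value) and value.lower() == 'true'

class QuietUnlessDebugHandler(BaseCallbackHandler):
    """Illustrative handler that stays silent unless DEBUG=true."""

    @property
    def ignore_llm(self) -> bool:
        return not _debug_enabled()

    def on_llm_start(self, serialized, prompts, **kwargs):
        print("[llm start]", prompts[0][:80])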
 from typing import Optional

-from langchain.callbacks import CallbackManager
-
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain
@@ -14,7 +12,7 @@ class ChainBuilder:
             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )

     @classmethod
@@ -27,7 +25,7 @@ class ChainBuilder:
             sensitive_words=sensitive_words.split(","),
             canned_response=tool_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+            callbacks=[DifyStdOutCallbackHandler()],
             **kwargs
         )
...
"""Base classes for LLM-powered router chains.""" """Base classes for LLM-powered router chains."""
from __future__ import annotations from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from pydantic import root_validator from pydantic import root_validator
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown from libs.json_in_md_parser import parse_and_check_json_markdown
...@@ -51,8 +52,9 @@ class LLMRouterChain(Chain): ...@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
raise ValueError raise ValueError
def _call( def _call(
self, self,
inputs: Dict[str, Any] inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
output = cast( output = cast(
Dict[str, Any], Dict[str, Any],
......
-from typing import Optional, List
+from typing import Optional, List, cast

-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory

-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task: ConversationMessageTask):
         first_input_key = "input"
         final_output_key = "output"
@@ -30,6 +29,7 @@ class MainChainBuilder:
         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )
@@ -42,9 +42,8 @@ class MainChainBuilder:
             return None

         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)

         # build main chain
         overall_chain = SequentialChain(
@@ -57,7 +56,9 @@ class MainChainBuilder:
         return overall_chain

     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task: ConversationMessageTask):
         # agent mode
         chains = []
@@ -93,7 +94,8 @@ class MainChainBuilder:
                 tenant_id=tenant_id,
                 datasets=datasets,
                 conversation_message_task=conversation_message_task,
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                rest_tokens=rest_tokens,
+                callbacks=[DifyStdOutCallbackHandler()]
             )
             chains.append(multi_dataset_router_chain)
...
+import math
 from typing import Mapping, List, Dict, Any, Optional

-from langchain import LLMChain, PromptTemplate, ConversationChain
-from langchain.callbacks import CallbackManager
+from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseLanguageModel
 from pydantic import Extra

 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule

+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \
@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):

     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""

     class Config:
@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
         )

-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                        for d in datasets]

         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # fulfill description when it is empty
+            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+                continue
+
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )

-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[str(dataset.id)] = dataset_tool

         return cls(
             router_chain=router_chain,
@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
             **kwargs,
         )

+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
             self,
-            inputs: Dict[str, Any]
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}
...
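The new _dynamic_calc_retrieve_k helper sizes the retrieval k from the remaining token budget: if even the full budget cannot hold DEFAULT_K segments it degrades to rest_tokens // segment_max_tokens, otherwise it takes a 30% share (CONTEXT_TOKENS_PERCENT) of rest_tokens and expands k to however many segments fit in that share, never going below DEFAULT_K. A worked example with illustrative numbers, not taken from the diff:

# Worked example of the new k calculation (numbers are illustrative).
import math

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
segment_max_tokens = 500
rest_tokens = 6000

# 6000 >= 500 * 2, so the degraded branch does not apply.
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)  # 1800
# 1800 > 500 * 2, so k expands beyond the default:
k = context_limit_tokens // segment_max_tokens                           # 3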
-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
                 return self.canned_response
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}
-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool
@@ -30,12 +31,20 @@ class ToolChain(Chain):
         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)
...
import logging import logging
from typing import Optional, List, Union, Tuple from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder from core.chain.main_chain_builder import MainChainBuilder
...@@ -34,8 +35,6 @@ class Completion: ...@@ -34,8 +35,6 @@ class Completion:
""" """
errors: ProviderTokenNotInitError errors: ProviderTokenNotInitError
""" """
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
memory = None memory = None
if conversation: if conversation:
# get memory of conversation (read-only) # get memory of conversation (read-only)
...@@ -48,6 +47,14 @@ class Completion: ...@@ -48,6 +47,14 @@ class Completion:
inputs = conversation.inputs inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
conversation_message_task = ConversationMessageTask( conversation_message_task = ConversationMessageTask(
task_id=task_id, task_id=task_id,
app=app, app=app,
...@@ -64,6 +71,7 @@ class Completion: ...@@ -64,6 +71,7 @@ class Completion:
main_chain = MainChainBuilder.to_langchain_components( main_chain = MainChainBuilder.to_langchain_components(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict, agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task conversation_message_task=conversation_message_task
) )
...@@ -115,7 +123,7 @@ class Completion: ...@@ -115,7 +123,7 @@ class Completion:
memory=memory memory=memory
) )
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task) final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=final_llm, final_llm=final_llm,
...@@ -247,16 +255,14 @@ And answer according to the language of the user's question. ...@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
return messages, ['\nHuman:'] return messages, ['\nHuman:']
@classmethod @classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool, streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager: conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming: if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else: else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()] return [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
@classmethod @classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
...@@ -293,7 +299,8 @@ And answer according to the language of the user's question. ...@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
return memory return memory
@classmethod @classmethod
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str): def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model( llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id, tenant_id=tenant_id,
model=app_model_config.model_dict model=app_model_config.model_dict
...@@ -302,8 +309,26 @@ And answer according to the language of the user's question. ...@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens max_tokens = llm.max_tokens
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0: # get prompt without memory and context
raise LLMBadRequestError("Query is too long") prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=None,
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
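The budget check above is simple arithmetic over the model's context window; a sketch with hypothetical numbers, for illustration only:
model_limited_tokens = 4096  # hypothetical context window of the selected model
max_tokens = 512             # completion budget reserved for the answer
prompt_tokens = 1200         # prompt measured without memory and context
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens  # 2384 tokens left for context and memory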
@classmethod @classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
...@@ -360,7 +385,7 @@ And answer according to the language of the user's question. ...@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
streaming=streaming streaming=streaming
) )
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=llm, final_llm=llm,
......
...@@ -293,12 +293,12 @@ class PubHandler: ...@@ -293,12 +293,12 @@ class PubHandler:
if not user: if not user:
raise ValueError("user is required") raise ValueError("user is required")
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result:{}-{}".format(user_str, task_id) return "generate_result:{}-{}".format(user_str, task_id)
@classmethod @classmethod
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id) return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str): def pub_text(self, text: str):
...@@ -306,10 +306,10 @@ class PubHandler: ...@@ -306,10 +306,10 @@ class PubHandler:
'event': 'message', 'event': 'message',
'data': { 'data': {
'task_id': self._task_id, 'task_id': self._task_id,
'message_id': self._message.id, 'message_id': str(self._message.id),
'text': text, 'text': text,
'mode': self._conversation.mode, 'mode': self._conversation.mode,
'conversation_id': self._conversation.id 'conversation_id': str(self._conversation.id)
} }
} }
......
import tempfile
from pathlib import Path
from typing import List, Union
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
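A hedged usage sketch of the new FileExtractor, assuming `FileExtractor`, `db` and `UploadFile` are imported as in the module above; the UploadFile lookup is hypothetical, any persisted upload record with a storage key would do:
upload_file = db.session.query(UploadFile).first()        # hypothetical: any stored upload record
docs = FileExtractor.load(upload_file)                    # List[Document], one per loader-specific unit
text = FileExtractor.load(upload_file, return_text=True)  # str, documents joined with '\n'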
import csv
import logging
from typing import Optional, Dict, List
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class CSVLoader(LCCSVLoader):
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def _read_from_file(self, csvfile):
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
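A minimal sketch of the loader above (the path is hypothetical): each CSV row becomes one Document whose page_content holds "header: value" lines and whose metadata records the source and row index.
loader = CSVLoader('/tmp/example.csv')  # hypothetical file
for doc in loader.load():
    print(doc.metadata['row'], doc.page_content)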
import json
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__)
class ExcelLoader(BaseLoader):
"""Load xlxs files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return [Document(page_content='\n\n'.join(data))]
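The Excel loader takes the first non-empty row it encounters as the column keys, serialises every later non-empty row as a JSON object, and returns them joined into a single Document; a sketch with a hypothetical workbook path:
loader = ExcelLoader('/tmp/example.xlsx')  # hypothetical workbook
doc = loader.load()[0]                     # one Document; rows as JSON objects separated by blank lines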
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
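The HTML loader simply strips markup with BeautifulSoup and returns a single Document; a sketch with a hypothetical path:
loader = HTMLLoader('/tmp/example.html')  # hypothetical file
text = loader.load()[0].page_content      # plain text extracted by BeautifulSoup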
"""Markdown parser. import logging
import re
from typing import Optional, List, Tuple, cast
Contains parser for md files. from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class MarkdownLoader(BaseLoader):
"""Load md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from llama_index.readers.file.base_parser import BaseParser Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
class MarkdownParser(BaseParser): remove_images: Whether to remove images from the text.
"""Markdown parser.
Extract text from markdown files. encoding: File encoding to use. If `None`, the file will be loaded
Returns dictionary with keys as headers and values as the text between headers. with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
""" """
def __init__( def __init__(
self, self,
*args: Any, file_path: str,
remove_hyperlinks: bool = True, remove_hyperlinks: bool = True,
remove_images: bool = True, remove_images: bool = True,
**kwargs: Any, encoding: Optional[str] = None,
) -> None: autodetect_encoding: bool = True,
"""Init params.""" ):
super().__init__(*args, **kwargs) """Initialize with file path."""
self._file_path = file_path
self._remove_hyperlinks = remove_hyperlinks self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images self._remove_images = remove_images
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
value = value.strip()
if header is None:
documents.append(Document(page_content=value))
else:
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
return documents
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary. """Convert a markdown file to a dictionary.
...@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser): ...@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
content = re.sub(pattern, r"\1", content) content = re.sub(pattern, r"\1", content)
return content return content
def _init_parser(self) -> Dict: def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples.""" """Parse file into tuples."""
with open(filepath, "r", encoding="utf-8") as f: content = ""
content = f.read() try:
with open(filepath, "r", encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {filepath}") from e
except Exception as e:
raise RuntimeError(f"Error loading {filepath}") from e
if self._remove_hyperlinks: if self._remove_hyperlinks:
content = self.remove_hyperlinks(content) content = self.remove_hyperlinks(content)
if self._remove_images: if self._remove_images:
content = self.remove_images(content) content = self.remove_images(content)
markdown_tups = self.markdown_to_tups(content)
return markdown_tups
def parse_file( return self.markdown_to_tups(content)
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results
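The reworked markdown loader keeps the header-splitting behaviour but now emits langchain Documents, one per (header, body) pair; a hedged sketch with a hypothetical path:
loader = MarkdownLoader('/tmp/notes.md', remove_hyperlinks=True, remove_images=True)
docs = loader.load()  # one Document per header section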
"""Notion reader."""
import json import json
import logging import logging
import os from typing import List, Dict, Any, Optional
from datetime import datetime
from typing import Any, Dict, List, Optional
import requests # type: ignore import requests
from flask import current_app
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from llama_index.readers.base import BaseReader from extensions.ext_database import db
from llama_index.readers.schema.base import Document from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
logger = logging.getLogger(__name__)
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search" SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
logger = logging.getLogger(__name__)
# TODO: Notion DB reader coming soon!
class NotionPageReader(BaseReader):
"""Notion Page reader.
Reads a set of Notion pages. class NotionLoader(BaseLoader):
def __init__(
Args: self,
integration_token (str): Notion integration token. notion_access_token: str,
notion_workspace_id: str,
""" notion_obj_id: str,
notion_page_type: str,
def __init__(self, integration_token: Optional[str] = None) -> None: document_model: Optional[DocumentModel] = None
"""Initialize with parameters.""" ):
if integration_token is None: self._document_model = document_model
integration_token = os.getenv(INTEGRATION_TOKEN_NAME) self._notion_workspace_id = notion_workspace_id
self._notion_obj_id = notion_obj_id
self._notion_page_type = notion_page_type
self._notion_access_token = notion_access_token
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
if integration_token is None: if integration_token is None:
raise ValueError( raise ValueError(
"Must specify `integration_token` or set environment " "Must specify `integration_token` or set environment "
"variable `NOTION_INTEGRATION_TOKEN`." "variable `NOTION_INTEGRATION_TOKEN`."
) )
self.token = integration_token
self.headers = {
"Authorization": "Bearer " + self.token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
}
def _read_block(self, block_id: str, num_tabs: int = 0) -> str: self._notion_access_token = integration_token
"""Read a block."""
done = False @classmethod
def from_document(cls, document_model: DocumentModel):
data_source_info = document_model.data_source_info_dict
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
notion_workspace_id = data_source_info['notion_workspace_id']
notion_obj_id = data_source_info['notion_page_id']
notion_page_type = data_source_info['type']
notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
return cls(
notion_access_token=notion_access_token,
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_obj_id,
notion_page_type=notion_page_type,
document_model=document_model
)
def load(self) -> List[Document]:
self.update_last_edited_time(
self._document_model
)
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
return text_docs
def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> List[Document]:
docs = []
if notion_page_type == 'database':
# get all the pages in the database
page_text = self._get_notion_database_data(notion_obj_id)
docs.append(Document(page_content=page_text))
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
for page_text in page_text_list:
docs.append(Document(page_content=page_text))
else:
raise ValueError("notion page type not supported")
return docs
def _get_notion_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> str:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return ""
for result in data["results"]:
properties = result['properties']
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
data[property_name] = value
database_content_list.append(json.dumps(data, ensure_ascii=False))
return "\n\n".join(database_content_list)
def _get_notion_block_data(self, page_id: str) -> List[str]:
result_lines_arr = [] result_lines_arr = []
cur_block_id = block_id cur_block_id = page_id
while not done: while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {} query_dict: Dict[str, Any] = {}
res = requests.request( res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict "GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
) )
data = res.json() data = res.json()
if 'results' not in data or data["results"] is None: # current block's heading
done = True
break
heading = '' heading = ''
for result in data["results"]: for result in data["results"]:
result_type = result["type"] result_type = result["type"]
...@@ -71,6 +165,7 @@ class NotionPageReader(BaseReader): ...@@ -71,6 +165,7 @@ class NotionPageReader(BaseReader):
if result_type == 'table': if result_type == 'table':
result_block_id = result["id"] result_block_id = result["id"]
text = self._read_table_rows(result_block_id) text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text) result_lines_arr.append(text)
else: else:
if "rich_text" in result_obj: if "rich_text" in result_obj:
...@@ -78,91 +173,53 @@ class NotionPageReader(BaseReader): ...@@ -78,91 +173,53 @@ class NotionPageReader(BaseReader):
# skip if doesn't have text object # skip if doesn't have text object
if "text" in rich_text: if "text" in rich_text:
text = rich_text["text"]["content"] text = rich_text["text"]["content"]
prefix = "\t" * num_tabs cur_result_text_arr.append(text)
cur_result_text_arr.append(prefix + text)
if result_type in HEADING_TYPE: if result_type in HEADING_TYPE:
heading = text heading = text
result_block_id = result["id"] result_block_id = result["id"]
has_children = result["has_children"] has_children = result["has_children"]
block_type = result["type"] block_type = result["type"]
if has_children and block_type != 'child_page': if has_children and block_type != 'child_page':
children_text = self._read_block( children_text = self._read_block(
result_block_id, num_tabs=num_tabs + 1 result_block_id, num_tabs=1
) )
cur_result_text_arr.append(children_text) cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr) cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE: if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text) result_lines_arr.append(cur_result_text)
else: else:
result_lines_arr.append(f'{heading}\n{cur_result_text}') result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None: if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr)
return result_lines
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
)
data = res.json()
# get table headers text
table_header_cell_texts = []
tabel_header_cells = data["results"][0]['table_row']['cells']
for tabel_header_cell in tabel_header_cells:
if tabel_header_cell:
for table_header_cell_text in tabel_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
# get table columns text and format
results = data["results"]
for i in range(len(results)-1):
column_texts = []
tabel_column_cells = data["results"][i+1]['table_row']['cells']
for j in range(len(tabel_column_cells)):
if tabel_column_cells[j]:
for table_column_cell_text in tabel_column_cells[j]:
column_text = table_column_cell_text["text"]["content"]
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
cur_result_text = "\n".join(column_texts)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
done = True
break break
else: else:
cur_block_id = data["next_cursor"] cur_block_id = data["next_cursor"]
return result_lines_arr
result_lines = "\n".join(result_lines_arr) def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
return result_lines
def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
"""Read a block.""" """Read a block."""
done = False
result_lines_arr = [] result_lines_arr = []
cur_block_id = block_id cur_block_id = block_id
while not done: while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {} query_dict: Dict[str, Any] = {}
res = requests.request( res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict "GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
) )
data = res.json() data = res.json()
# current block's heading if 'results' not in data or data["results"] is None:
break
heading = '' heading = ''
for result in data["results"]: for result in data["results"]:
result_type = result["type"] result_type = result["type"]
...@@ -171,7 +228,6 @@ class NotionPageReader(BaseReader): ...@@ -171,7 +228,6 @@ class NotionPageReader(BaseReader):
if result_type == 'table': if result_type == 'table':
result_block_id = result["id"] result_block_id = result["id"]
text = self._read_table_rows(result_block_id) text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text) result_lines_arr.append(text)
else: else:
if "rich_text" in result_obj: if "rich_text" in result_obj:
...@@ -179,10 +235,10 @@ class NotionPageReader(BaseReader): ...@@ -179,10 +235,10 @@ class NotionPageReader(BaseReader):
# skip if doesn't have text object # skip if doesn't have text object
if "text" in rich_text: if "text" in rich_text:
text = rich_text["text"]["content"] text = rich_text["text"]["content"]
cur_result_text_arr.append(text) prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
if result_type in HEADING_TYPE: if result_type in HEADING_TYPE:
heading = text heading = text
result_block_id = result["id"] result_block_id = result["id"]
has_children = result["has_children"] has_children = result["has_children"]
block_type = result["type"] block_type = result["type"]
...@@ -193,177 +249,121 @@ class NotionPageReader(BaseReader): ...@@ -193,177 +249,121 @@ class NotionPageReader(BaseReader):
cur_result_text_arr.append(children_text) cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr) cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE: if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text) result_lines_arr.append(cur_result_text)
else: else:
result_lines_arr.append(f'{heading}\n{cur_result_text}') result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None: if data["next_cursor"] is None:
done = True
break break
else: else:
cur_block_id = data["next_cursor"] cur_block_id = data["next_cursor"]
return result_lines_arr
def read_page(self, page_id: str) -> str:
"""Read a page."""
return self._read_block(page_id)
def read_page_as_documents(self, page_id: str) -> List[str]:
"""Read a page as documents."""
return self._read_parent_blocks(page_id)
def query_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> str:
"""Get all the pages from a Notion database."""
res = requests.post\
(
DATABASE_URL_TMPL.format(database_id=database_id),
headers=self.headers,
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return ""
for result in data["results"]:
properties = result['properties']
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
data[property_name] = value
database_content_list.append(json.dumps(data, ensure_ascii=False))
return "\n\n".join(database_content_list)
def query_database(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> List[str]:
"""Get all the pages from a Notion database."""
res = requests.post\
(
DATABASE_URL_TMPL.format(database_id=database_id),
headers=self.headers,
json=query_dict,
)
data = res.json()
page_ids = []
for result in data["results"]:
page_id = result["id"]
page_ids.append(page_id)
return page_ids result_lines = "\n".join(result_lines_arr)
return result_lines
def search(self, query: str) -> List[str]: def _read_table_rows(self, block_id: str) -> str:
"""Search Notion page given a text query.""" """Read table rows."""
done = False done = False
next_cursor: Optional[str] = None result_lines_arr = []
page_ids = [] cur_block_id = block_id
while not done: while not done:
query_dict = { block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
"query": query, query_dict: Dict[str, Any] = {}
}
if next_cursor is not None: res = requests.request(
query_dict["start_cursor"] = next_cursor "GET",
res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict) block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json() data = res.json()
for result in data["results"]: # get table headers text
page_id = result["id"] table_header_cell_texts = []
page_ids.append(page_id) tabel_header_cells = data["results"][0]['table_row']['cells']
for tabel_header_cell in tabel_header_cells:
if tabel_header_cell:
for table_header_cell_text in tabel_header_cell:
text = table_header_cell_text["text"]["content"]
table_header_cell_texts.append(text)
# get table columns text and format
results = data["results"]
for i in range(len(results) - 1):
column_texts = []
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
for j in range(len(tabel_column_cells)):
if tabel_column_cells[j]:
for table_column_cell_text in tabel_column_cells[j]:
column_text = table_column_cell_text["text"]["content"]
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
cur_result_text = "\n".join(column_texts)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None: if data["next_cursor"] is None:
done = True done = True
break break
else: else:
next_cursor = data["next_cursor"] cur_block_id = data["next_cursor"]
return page_ids
def load_data( result_lines = "\n".join(result_lines_arr)
self, page_ids: List[str] = [], database_id: Optional[str] = None return result_lines
) -> List[Document]:
"""Load data from the input directory.
Args: def update_last_edited_time(self, document_model: DocumentModel):
page_ids (List[str]): List of page ids to load. if not document_model:
return
Returns: last_edited_time = self.get_notion_last_edited_time()
List[Document]: List of documents. data_source_info = document_model.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
DocumentModel.data_source_info: json.dumps(data_source_info)
}
""" DocumentModel.query.filter_by(id=document_model.id).update(update_params)
if not page_ids and not database_id: db.session.commit()
raise ValueError("Must specify either `page_ids` or `database_id`.")
docs = []
if database_id is not None:
# get all the pages in the database
page_ids = self.query_database(database_id)
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text))
else:
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text))
return docs def get_notion_last_edited_time(self) -> str:
obj_id = self._notion_obj_id
def load_data_as_documents( page_type = self._notion_page_type
self, page_ids: List[str] = [], database_id: Optional[str] = None if page_type == 'database':
) -> List[Document]: retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
if not page_ids and not database_id:
raise ValueError("Must specify either `page_ids` or `database_id`.")
docs = []
if database_id is not None:
# get all the pages in the database
page_text = self.query_database_data(database_id)
docs.append(Document(page_text))
else: else:
for page_id in page_ids: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
page_text_list = self.read_page_as_documents(page_id)
for page_text in page_text_list:
docs.append(Document(page_text))
return docs
def get_page_last_edited_time(self, page_id: str) -> str:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
query_dict: Dict[str, Any] = {} query_dict: Dict[str, Any] = {}
res = requests.request( res = requests.request(
"GET", retrieve_page_url, headers=self.headers, json=query_dict "GET",
retrieve_page_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
) )
data = res.json()
return data["last_edited_time"]
def get_database_last_edited_time(self, database_id: str) -> str:
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", retrieve_page_url, headers=self.headers, json=query_dict
)
data = res.json() data = res.json()
return data["last_edited_time"] return data["last_edited_time"]
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
)
).first()
if not data_source_binding:
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
f'and notion workspace {notion_workspace_id}')
if __name__ == "__main__": return data_source_binding.access_token
reader = NotionPageReader()
logger.info(reader.search("What I"))
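A hedged sketch of constructing the new NotionLoader from a dataset document; the DocumentModel lookup is hypothetical, and its data_source_info must contain the notion workspace and page ids:
document_model = db.session.query(DocumentModel).first()  # hypothetical notion-sourced document row
loader = NotionLoader.from_document(document_model)       # access token resolved via DataSourceBinding
notion_docs = loader.load()                               # List[Document] for the page or database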
import logging
from typing import List, Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents
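PdfLoader checks storage for a cached plaintext copy keyed by the upload's hash and only parses with PyPDFium2 on a cache miss; a minimal sketch with a hypothetical path and no cached upload record:
loader = PdfLoader('/tmp/example.pdf', upload_file=None)  # hypothetical path; caching skipped without upload_file
pdf_docs = loader.load()                                  # one Document per page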
from typing import Any, Dict, Optional, Sequence from typing import Any, Dict, Optional, Sequence
import tiktoken from langchain.schema import Document
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from sqlalchemy import func from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator from core.llm.token_calculator import TokenCalculator
...@@ -12,7 +8,7 @@ from extensions.ext_database import db ...@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore(BaseDocumentStore): class DatesetDocumentStore:
def __init__( def __init__(
self, self,
dataset: Dataset, dataset: Dataset,
...@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
return self._embedding_model_name return self._embedding_model_name
@property @property
def docs(self) -> Dict[str, BaseDocument]: def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter( document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id DocumentSegment.dataset_id == self._dataset.id
).all() ).all()
...@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
output = {} output = {}
for document_segment in document_segments: for document_segment in document_segments:
doc_id = document_segment.index_node_id doc_id = document_segment.index_node_id
result = self.segment_to_dict(document_segment) output[doc_id] = Document(
output[doc_id] = json_to_doc(result) page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
return output return output
def add_documents( def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True self, docs: Sequence[Document], allow_update: bool = True
) -> None: ) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter( max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id DocumentSegment.document == self._document_id
...@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
max_position = 0 max_position = 0
for doc in docs: for doc in docs:
if doc.is_doc_id_none: if not isinstance(doc, Document):
raise ValueError("doc_id not set") raise ValueError("doc must be a Document")
if not isinstance(doc, Node): segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
raise ValueError("doc must be a Node")
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
# NOTE: doc could already exist in the store, but we overwrite it # NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document: if not allow_update and segment_document:
raise ValueError( raise ValueError(
f"doc_id {doc.get_doc_id()} already exists. " f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite." "Set allow_update to True to overwrite."
) )
# calc embedding use tokens # calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text()) tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
if not segment_document: if not segment_document:
max_position += 1 max_position += 1
...@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
tenant_id=self._dataset.tenant_id, tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id, dataset_id=self._dataset.id,
document_id=self._document_id, document_id=self._document_id,
index_node_id=doc.get_doc_id(), index_node_id=doc.metadata['doc_id'],
index_node_hash=doc.get_doc_hash(), index_node_hash=doc.metadata['doc_hash'],
position=max_position, position=max_position,
content=doc.get_text(), content=doc.page_content,
word_count=len(doc.get_text()), word_count=len(doc.page_content),
tokens=tokens, tokens=tokens,
created_by=self._user_id, created_by=self._user_id,
) )
db.session.add(segment_document) db.session.add(segment_document)
else: else:
segment_document.content = doc.get_text() segment_document.content = doc.page_content
segment_document.index_node_hash = doc.get_doc_hash() segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.get_text()) segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens segment_document.tokens = tokens
db.session.commit() db.session.commit()
...@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
def get_document( def get_document(
self, doc_id: str, raise_error: bool = True self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]: ) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id) document_segment = self.get_document_segment(doc_id)
if document_segment is None: if document_segment is None:
...@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
else: else:
return None return None
result = self.segment_to_dict(document_segment) return Document(
return json_to_doc(result) page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None: def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
document_segment = self.get_document_segment(doc_id) document_segment = self.get_document_segment(doc_id)
...@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
return document_segment.index_node_hash return document_segment.index_node_hash
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
def get_document_segment(self, doc_id: str) -> DocumentSegment: def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter( document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.dataset_id == self._dataset.id,
...@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
).first() ).first()
return document_segment return document_segment
def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
return {
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"text": segment.content,
"__type__": Node.get_type()
}
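After the migration away from llama_index nodes, segments travel as plain langchain Documents whose metadata carries the identifiers the store reads back; a hedged sketch of that shape, all values hypothetical:
from langchain.schema import Document

segment_doc = Document(
    page_content="segment text",
    metadata={
        "doc_id": "index-node-id",      # maps to DocumentSegment.index_node_id
        "doc_hash": "index-node-hash",  # maps to DocumentSegment.index_node_hash
        "document_id": "document-uuid",
        "dataset_id": "dataset-uuid",
    },
)
# DatesetDocumentStore.add_documents([segment_doc]) would upsert a DocumentSegment from these fields.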
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {}
@property
def docs(self) -> Dict[str, BaseDocument]:
return {}
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
pass
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
return False
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
return None
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
pass
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
pass
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
return None
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
import logging
from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
# use doc embedding cache or store if not exists
text_embeddings = []
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
# enumerate keeps the text/result pairing correct even when a row insert is skipped
for i, text in enumerate(embedding_queue_texts):
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
text_embeddings.extend(embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to db')
return embedding_results
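CacheEmbedding is a drop-in langchain Embeddings wrapper that memoises vectors in the Embedding table; a hedged usage sketch, with credentials omitted and assumed to be configured elsewhere:
from langchain.embeddings import OpenAIEmbeddings

embeddings = CacheEmbedding(OpenAIEmbeddings())  # hypothetical: real OpenAI credentials required in practice
vector = embeddings.embed_query("hello world")   # stored in the Embedding table on first use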
from typing import Optional, Any, List
import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
_TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
By default, this is a wrapper around _get_text_embedding.
Can be overridden for batch queries.
"""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(self._get_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(await self._aget_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any
from langchain.schema import Document, BaseRetriever
from models.dataset import Dataset
class BaseIndex(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
@abstractmethod
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
raise NotImplementedError
@abstractmethod
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
\ No newline at end of file
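A hedged sketch of how callers obtain an index through the new IndexBuilder; the Dataset lookup is hypothetical:
dataset = db.session.query(Dataset).first()              # hypothetical dataset row
index = IndexBuilder.get_index(dataset, 'high_quality')  # None unless the dataset uses high_quality indexing
if index:
    docs = index.search("example query")                 # List[Document] from the vector index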
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder
class IndexBuilder:
@classmethod
def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
# set number of output tokens
num_output = 512
# only for verbose
callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='text-davinci-003',
temperature=0,
max_tokens=num_output,
callback_manager=callback_manager,
)
llm_predictor = LLMPredictor(llm=llm)
# These parameters here will affect the logic of segmenting the final synthesized response.
# The number of refinement iterations in the synthesis process depends
# on whether the length of the segmented output exceeds the max_input_size.
prompt_helper = PromptHelper(
max_input_size=3500,
num_output=num_output,
max_chunk_overlap=20
)
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002'
)
return ServiceContext.from_defaults(
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
embed_model=OpenAIEmbedding(**model_credentials),
)
@classmethod
def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='fake'
)
return ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
import re
from typing import (
Any,
Dict,
List,
Set,
Optional
)
import jieba.analyse
from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
def jieba_extract_keywords(
text_chunk: str,
max_keywords: Optional[int] = None,
expand_with_subtokens: bool = True,
) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text_chunk,
topK=max_keywords,
)
if expand_with_subtokens:
return set(expand_tokens_with_subtokens(keywords))
else:
return set(keywords)
def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
            results.update({w for w in sub_tokens if w not in STOPWORDS})
return results
class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
"""GPT JIEBA Keyword Table Index.
This index uses a JIEBA keyword extractor to extract keywords from the text.
"""
def _extract_keywords(self, text: str) -> Set[str]:
"""Extract keywords from text."""
return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
    @classmethod
    def get_query_map(cls) -> QueryMap:
"""Get query map."""
super_map = super().get_query_map()
super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
return super_map
def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
# get set of ids that correspond to node
node_idxs_to_delete = {doc_id}
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in self._index_struct.table.items():
if node_idxs_to_delete.intersection(node_idxs):
self._index_struct.table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not self._index_struct.table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del self._index_struct.table[keyword]
class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
"""GPT Keyword Table Index JIEBA Query.
Extracts keywords using JIEBA keyword extractor.
Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
.. code-block:: python
response = index.query("<query_str>", mode="jieba")
See BaseGPTKeywordTableQuery for arguments.
"""
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
def _get_keywords(self, query_str: str) -> List[str]:
"""Extract keywords."""
return list(
jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
)
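if __name__ == "__main__":
    # Minimal illustrative check of the jieba-based extractor above; the sample
    # sentence is arbitrary and only meant to show keyword and sub-token expansion.
    sample_text = "Dify 是一个开源的 LLM 应用开发平台,支持数据集的索引与检索。"
    print(jieba_extract_keywords(sample_text, max_keywords=5))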
import json
from typing import List, Optional
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
index_struct = KeywordTable()
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node in nodes:
keywords = index._extract_keywords(node.get_text())
self.update_segment_keywords(node.doc_id, list(keywords))
index._index_struct.add_node(list(keywords), node)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
def del_nodes(self, node_ids: List[str]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node_id in node_ids:
index.delete(node_id)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
@property
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
docstore = DatesetDocumentStore(
dataset=self._dataset,
user_id=self._dataset.created_by,
embedding_model_name="text-embedding-ada-002"
)
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return None
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
def get_keyword_table(self):
dataset_keyword_table = self._dataset.dataset_keyword_table
if dataset_keyword_table:
return dataset_keyword_table
return None
def update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
from core.index.keyword_table_index.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
                results.update({w for w in sub_tokens if w not in STOPWORDS})
return results
\ No newline at end of file
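if __name__ == "__main__":
    # Minimal illustrative check: extract up to 10 keywords (plus sub-tokens) from an
    # arbitrary sample sentence using the handler above.
    handler = JiebaKeywordTableHandler()
    print(handler.extract_keywords("Dify 支持基于关键词表的经济型索引方式。"))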
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict
from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra
from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class KeywordTableIndex(BaseIndex):
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
super().__init__(dataset)
self._config = config
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
        if not keyword_table:
            return False
        return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
k = search_kwargs.get('k') if search_kwargs.get('k') else 4
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
return documents
def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
}
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:
return dataset_keyword_table.keyword_table_dict['__data__']['table']
else:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
list(chunk_indices_count.keys()),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
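if __name__ == "__main__":
    # SetEncoder exists so the keyword table (keyword -> set of segment node ids) can
    # be serialized with json.dumps; a minimal demonstration with a hypothetical table:
    print(json.dumps({"关键词": {"node-1", "node-2"}}, cls=SetEncoder))

    # Building and querying a real index needs a Dataset ORM object and an app context;
    # roughly (illustrative sketch only):
    #   index = KeywordTableIndex(dataset, KeywordTableConfig(max_keywords_per_chunk=10))
    #   index.create(documents)
    #   results = index.search("查询内容", search_kwargs={'k': 4})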
from typing import (
Any,
Dict,
Optional, Sequence,
)
from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE
class EnhanceResponseSynthesizer(ResponseSynthesizer):
@classmethod
def from_args(
cls,
service_context: ServiceContext,
streaming: bool = False,
use_async: bool = False,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_mode: ResponseMode = ResponseMode.DEFAULT,
response_kwargs: Optional[Dict] = None,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
) -> "ResponseSynthesizer":
response_builder: Optional[BaseResponseBuilder] = None
if response_mode != ResponseMode.NO_TEXT:
if response_mode == 'no_synthesizer':
response_builder = NoSynthesizer(
service_context=service_context,
simple_template=simple_template,
streaming=streaming,
)
else:
response_builder = get_response_builder(
service_context,
text_qa_template,
refine_template,
simple_template,
response_mode,
use_async=use_async,
streaming=streaming,
)
return cls(response_builder, response_mode, response_kwargs, optimizer)
class NoSynthesizer(BaseResponseBuilder):
def __init__(
self,
service_context: ServiceContext,
simple_template: Optional[SimpleInputPrompt] = None,
streaming: bool = False,
) -> None:
super().__init__(service_context, streaming)
async def aget_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
\ No newline at end of file
from pathlib import Path
from typing import Dict
from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
with open(file, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
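if __name__ == "__main__":
    # Illustrative only: parse a local HTML file; "example.html" is a placeholder path
    # assumed to exist alongside this script.
    print(HTMLParser().parse_file(Path("example.html")))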
from pathlib import Path
from typing import Dict
from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader
from extensions.ext_storage import storage
from models.model import UploadFile
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if not current_app.config.get('PDF_PREVIEW', True):
return ''
plaintext_file_key = ''
plaintext_file_exists = False
if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
upload_file: UploadFile = self._parser_config['upload_file']
if upload_file.hash:
plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return text
except FileNotFoundError:
pass
text_list = []
with open(file, "rb") as fp:
# Create a PDF object
pdf = PdfReader(fp)
# Get the number of pages in the PDF document
num_pages = len(pdf.pages)
# Iterate over every page
for page in range(num_pages):
# Extract the text from the page
page_text = pdf.pages[page].extract_text()
text_list.append(page_text)
text = "\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return text
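if __name__ == "__main__":
    # The parser above wraps this core pypdf loop with a Flask config switch and a
    # plaintext cache in storage; a standalone sketch of just the extraction step
    # ("example.pdf" is a placeholder path):
    with open("example.pdf", "rb") as fp:
        reader = PdfReader(fp)
        print("\n".join(page.extract_text() or "" for page in reader.pages))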
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app
class XLSXParser(BaseParser):
"""XLSX parser."""
def _init_parser(self) -> Dict:
"""Init parser"""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
data = []
keys = []
with open(file, "r") as fp:
wb = load_workbook(filename=file, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return '\n\n'.join(data)
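if __name__ == "__main__":
    # Illustrative only: every non-empty row after the header row becomes one JSON
    # object line; "example.xlsx" is a placeholder path assumed to exist.
    print(XLSXParser().parse_file(Path("example.xlsx")))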
import json
import logging
from typing import List, Optional
from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding
class VectorIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
if duplicate_check:
nodes = self._filter_duplicate_nodes(index, nodes)
embedding_queue_nodes = []
embedded_nodes = []
for node in nodes:
node_hash = node.doc_hash
# if node hash in cached embedding tables, use cached embedding
embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
if embedding:
node.embedding = embedding.get_embedding()
embedded_nodes.append(node)
else:
embedding_queue_nodes.append(node)
if embedding_queue_nodes:
embedding_results = index._get_node_embedding_results(
embedding_queue_nodes,
set(),
)
# pre embed nodes for cached embedding
for embedding_result in embedding_results:
node = embedding_result.node
node.embedding = embedding_result.embedding
try:
embedding = Embedding(hash=node.doc_hash)
embedding.set_embedding(node.embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
embedded_nodes.append(node)
self.index_insert_nodes(index, embedded_nodes)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
index.insert_nodes(nodes)
def del_nodes(self, node_ids: List[str]):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
for node_id in node_ids:
self.index_delete_node(index, node_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
index.delete_node(node_id)
def del_doc(self, doc_id: str):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
self.index_delete_doc(index, doc_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
index.delete(doc_id)
@property
def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
if not self._dataset.index_struct_dict:
return None
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
return vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        # build a new list instead of removing items from `nodes` while iterating over it,
        # which would skip the element that follows each removed node
        return [node for node in nodes if not index.exists_by_node_id(node.doc_id)]
import json
import logging
from abc import abstractmethod
from typing import List, Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
import os
from typing import Optional, Any, List, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
# original class_prefix
return True
return False
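if __name__ == "__main__":
    # QdrantConfig maps a 'path:' endpoint to local on-disk storage and anything else
    # to a remote URL; a quick illustrative check with placeholder values:
    print(QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api').to_qdrant_params())
    print(QdrantConfig(endpoint='https://qdrant.example.com', api_key='secret', root_path=None).to_qdrant_params())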
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError(f"Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
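# Illustrative sketch (hedged: assumes a Dataset ORM object, a Flask app context and a
# configured Embeddings instance, e.g. CacheEmbedding(OpenAIEmbeddings(...))):
#
#   index = VectorIndex(dataset=dataset, config=current_app.config, embeddings=embeddings)
#   index.add_texts(documents, duplicate_check=True)   # first call creates the index
#   docs = index.search("query", search_kwargs={'k': 4})  # proxied to the concrete index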
from typing import Optional, cast
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
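if __name__ == "__main__":
    # Illustrative only: WeaviateConfig requires an endpoint (enforced by the
    # root_validator above); the values here are placeholders.
    print(WeaviateConfig(endpoint='http://localhost:8080', api_key='secret', batch_size=100).dict())
    # New datasets get a class name of the form "Vector_index_<dataset_id>_Node",
    # while legacy classes without the "_Node" suffix are treated as origin indexes.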
import datetime import datetime
import json import json
import logging
import re import re
import tempfile
import time import time
from pathlib import Path import uuid
from typing import Optional, List from typing import Optional, List, cast
from flask import current_app
from flask_login import current_user from flask_login import current_user
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from llama_index import SimpleDirectoryReader from core.data_loader.file_extractor import FileExtractor
from llama_index.data_structs import Node from core.data_loader.loader.notion import NotionLoader
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser
from core.data_source.notion import NotionPageReader
from core.index.readers.xlsx_parser import XLSXParser
from core.docstore.dataset_docstore import DatesetDocumentStore from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex from core.embedding.cached_embedding import CacheEmbedding
from core.index.readers.html_parser import HTMLParser from core.index.index import IndexBuilder
from core.index.readers.markdown_parser import MarkdownParser from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.readers.pdf_parser import PDFParser from core.index.vector_index.vector_index import VectorIndex
from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter from core.llm.error import ProviderTokenNotInitError
from core.index.vector_index import VectorIndex from core.llm.llm_builder import LLMBuilder
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile from models.model import UploadFile
from models.source import DataSourceBinding from models.source import DataSourceBinding
...@@ -40,135 +39,171 @@ class IndexingRunner: ...@@ -40,135 +39,171 @@ class IndexingRunner:
self.storage = storage self.storage = storage
self.embedding_model_name = embedding_model_name self.embedding_model_name = embedding_model_name
def run(self, documents: List[Document]): def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process.""" """Run the indexing process."""
for document in documents: for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get splitter
splitter = self._get_splitter(processing_rule)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset # get dataset
dataset = Dataset.query.filter_by( dataset = Dataset.query.filter_by(
id=document.dataset_id id=dataset_document.dataset_id
).first() ).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=dataset_document.id
).all()
db.session.delete(document_segments)
db.session.commit()
# load file # load file
text_docs = self._load_data(document) text_docs = self._load_data(dataset_document)
# get the process rule # get the process rule
processing_rule = db.session.query(DatasetProcessRule). \ processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first() first()
# get node parser for splitting # get splitter
node_parser = self._get_node_parser(processing_rule) splitter = self._get_splitter(processing_rule)
# split to nodes # split to documents
nodes = self._step_split( documents = self._step_split(
text_docs=text_docs, text_docs=text_docs,
node_parser=node_parser, splitter=splitter,
dataset=dataset, dataset=dataset,
document=document, dataset_document=dataset_document,
processing_rule=processing_rule processing_rule=processing_rule
) )
# build index # build index
self._build_index( self._build_index(
dataset=dataset, dataset=dataset,
document=document, dataset_document=dataset_document,
nodes=nodes documents=documents
) )
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, document: Document): def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting.""" """Run the indexing process when the index_status is indexing."""
# get dataset try:
dataset = Dataset.query.filter_by( # get dataset
id=document.dataset_id dataset = Dataset.query.filter_by(
).first() id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
db.session.delete(document_segments)
db.session.commit()
# load file
text_docs = self._load_data(document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes if not dataset:
nodes = self._step_split( raise ValueError("no dataset found")
text_docs=text_docs,
node_parser=node_parser,
dataset=dataset,
document=document,
processing_rule=processing_rule
)
# build index # get exist document_segment list and delete
self._build_index( document_segments = DocumentSegment.query.filter_by(
dataset=dataset, dataset_id=dataset.id,
document=document, document_id=dataset_document.id
nodes=nodes ).all()
)
documents = []
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
document = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
documents.append(document)
def run_in_indexing_status(self, document: Document): # build index
"""Run the indexing process when the index_status is indexing.""" self._build_index(
# get dataset dataset=dataset,
dataset = Dataset.query.filter_by( dataset_document=dataset_document,
id=document.dataset_id documents=documents
).first() )
except DocumentIsPausedException:
if not dataset: raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
raise ValueError("no dataset found") except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
# get exist document_segment list and delete dataset_document.error = str(e.description)
document_segments = DocumentSegment.query.filter_by( dataset_document.stopped_at = datetime.datetime.utcnow()
dataset_id=dataset.id, db.session.commit()
document_id=document.id except Exception as e:
).all() logging.exception("consume document failed")
nodes = [] dataset_document.indexing_status = 'error'
if document_segments: dataset_document.error = str(e)
for document_segment in document_segments: dataset_document.stopped_at = datetime.datetime.utcnow()
# transform segment to node db.session.commit()
if document_segment.status != "completed":
relationships = {
DocumentRelationship.SOURCE: document_segment.document_id,
}
previous_segment = document_segment.previous_segment
if previous_segment:
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
next_segment = document_segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=document_segment.index_node_id,
doc_hash=document_segment.index_node_hash,
text=document_segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
nodes.append(node)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
""" """
...@@ -179,28 +214,28 @@ class IndexingRunner: ...@@ -179,28 +214,28 @@ class IndexingRunner:
total_segments = 0 total_segments = 0
for file_detail in file_details: for file_detail in file_details:
# load data from file # load data from file
text_docs = self._load_data_from_file(file_detail) text_docs = FileExtractor.load(file_detail)
processing_rule = DatasetProcessRule( processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"]) rules=json.dumps(tmp_processing_rule["rules"])
) )
# get node parser for splitting # get splitter
node_parser = self._get_node_parser(processing_rule) splitter = self._get_splitter(processing_rule)
# split to nodes # split to documents
nodes = self._split_to_nodes( documents = self._split_to_documents(
text_docs=text_docs, text_docs=text_docs,
node_parser=node_parser, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule
) )
total_segments += len(nodes) total_segments += len(documents)
for node in nodes: for document in documents:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(node.get_text()) preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return { return {
"total_segments": total_segments, "total_segments": total_segments,
...@@ -230,35 +265,36 @@ class IndexingRunner: ...@@ -230,35 +265,36 @@ class IndexingRunner:
).first() ).first()
if not data_source_binding: if not data_source_binding:
raise ValueError('Data source binding not found.') raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']: for page in notion_info['pages']:
if page['type'] == 'page': loader = NotionLoader(
page_ids = [page['page_id']] notion_access_token=data_source_binding.access_token,
documents = reader.load_data_as_documents(page_ids=page_ids) notion_workspace_id=workspace_id,
elif page['type'] == 'database': notion_obj_id=page['page_id'],
documents = reader.load_data_as_documents(database_id=page['page_id']) notion_page_type=page['type']
else: )
documents = [] documents = loader.load()
processing_rule = DatasetProcessRule( processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"]) rules=json.dumps(tmp_processing_rule["rules"])
) )
# get node parser for splitting # get splitter
node_parser = self._get_node_parser(processing_rule) splitter = self._get_splitter(processing_rule)
# split to nodes # split to documents
nodes = self._split_to_nodes( documents = self._split_to_documents(
text_docs=documents, text_docs=documents,
node_parser=node_parser, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule
) )
total_segments += len(nodes) total_segments += len(documents)
for node in nodes: for document in documents:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(node.get_text()) preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return { return {
"total_segments": total_segments, "total_segments": total_segments,
...@@ -268,14 +304,14 @@ class IndexingRunner: ...@@ -268,14 +304,14 @@ class IndexingRunner:
"preview": preview_texts "preview": preview_texts
} }
def _load_data(self, document: Document) -> List[Document]: def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
# load file # load file
if document.data_source_type not in ["upload_file", "notion_import"]: if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return [] return []
data_source_info = document.data_source_info_dict data_source_info = dataset_document.data_source_info_dict
text_docs = [] text_docs = []
if document.data_source_type == 'upload_file': if dataset_document.data_source_type == 'upload_file':
if not data_source_info or 'upload_file_id' not in data_source_info: if not data_source_info or 'upload_file_id' not in data_source_info:
raise ValueError("no upload file found") raise ValueError("no upload file found")
...@@ -283,47 +319,28 @@ class IndexingRunner: ...@@ -283,47 +319,28 @@ class IndexingRunner:
filter(UploadFile.id == data_source_info['upload_file_id']). \ filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none() one_or_none()
text_docs = self._load_data_from_file(file_detail) text_docs = FileExtractor.load(file_detail)
elif document.data_source_type == 'notion_import': elif dataset_document.data_source_type == 'notion_import':
if not data_source_info or 'notion_page_id' not in data_source_info \ loader = NotionLoader.from_document(dataset_document)
or 'notion_workspace_id' not in data_source_info: text_docs = loader.load()
raise ValueError("no notion page found")
workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == document.tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
if page_type == 'page':
# add page last_edited_time to data_source_info
self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
elif page_type == 'database':
# add page last_edited_time to data_source_info
self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
# update document status to splitting # update document status to splitting
self._update_document_index_status( self._update_document_index_status(
document_id=document.id, document_id=dataset_document.id,
after_indexing_status="splitting", after_indexing_status="splitting",
extra_update_params={ extra_update_params={
Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
Document.parsing_completed_at: datetime.datetime.utcnow() DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
} }
) )
# replace doc id to document model id # replace doc id to document model id
text_docs = cast(List[Document], text_docs)
for text_doc in text_docs: for text_doc in text_docs:
# remove invalid symbol # remove invalid symbol
text_doc.text = self.filter_string(text_doc.get_text()) text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.doc_id = document.id text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
return text_docs return text_docs
...@@ -331,61 +348,7 @@ class IndexingRunner: ...@@ -331,61 +348,7 @@ class IndexingRunner:
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]') pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
return pattern.sub('', text) return pattern.sub('', text)
def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]: def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
self.storage.download(upload_file.key, filepath)
file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
file_extractor[".markdown"] = MarkdownParser()
file_extractor[".md"] = MarkdownParser()
file_extractor[".html"] = HTMLParser()
file_extractor[".htm"] = HTMLParser()
file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
file_extractor[".xlsx"] = XLSXParser()
loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
text_docs = loader.load_data()
return text_docs
def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
page_ids = [page_id]
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(page_ids=page_ids)
return text_docs
def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(database_id=database_id)
return text_docs
def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_page_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_database_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
""" """
...@@ -414,68 +377,83 @@ class IndexingRunner: ...@@ -414,68 +377,83 @@ class IndexingRunner:
separators=["\n\n", "。", ".", " ", ""] separators=["\n\n", "。", ".", " ", ""]
) )
return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True) return character_splitter
def _step_split(self, text_docs: List[Document], node_parser: NodeParser, def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]: dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-> List[Document]:
""" """
Split the text documents into nodes and save them to the document segment. Split the text documents into documents and save them to the document segment.
""" """
nodes = self._split_to_nodes( documents = self._split_to_documents(
text_docs=text_docs, text_docs=text_docs,
node_parser=node_parser, splitter=splitter,
processing_rule=processing_rule processing_rule=processing_rule
) )
# save node to document segment # save node to document segment
doc_store = DatesetDocumentStore( doc_store = DatesetDocumentStore(
dataset=dataset, dataset=dataset,
user_id=document.created_by, user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name, embedding_model_name=self.embedding_model_name,
document_id=document.id document_id=dataset_document.id
) )
# add document segments # add document segments
doc_store.add_documents(nodes) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.utcnow() cur_time = datetime.datetime.utcnow()
self._update_document_index_status( self._update_document_index_status(
document_id=document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
extra_update_params={ extra_update_params={
Document.cleaning_completed_at: cur_time, DatasetDocument.cleaning_completed_at: cur_time,
Document.splitting_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time,
} }
) )
# update segment status to indexing # update segment status to indexing
self._update_segments_by_document( self._update_segments_by_document(
document_id=document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow() DocumentSegment.indexing_at: datetime.datetime.utcnow()
} }
) )
return nodes return documents
def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser, def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Node]: processing_rule: DatasetProcessRule) -> List[Document]:
""" """
Split the text documents into nodes. Split the text documents into documents.
""" """
all_nodes = [] all_documents = []
for text_doc in text_docs: for text_doc in text_docs:
# document clean # document clean
document_text = self._document_clean(text_doc.get_text(), processing_rule) document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.text = document_text text_doc.page_content = document_text
# parse document to nodes # parse document to nodes
nodes = node_parser.get_nodes_from_documents([text_doc]) documents = splitter.split_documents([text_doc])
nodes = [node for node in nodes if node.text is not None and node.text.strip()]
all_nodes.extend(nodes) split_documents = []
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
split_documents.append(document)
all_documents.extend(split_documents)
return all_nodes return all_documents
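Editor's note: the rewritten _split_to_documents replaces llama-index nodes with langchain Document objects — each chunk is cleaned, skipped if empty, and tagged with a fresh doc_id plus a doc_hash in its metadata. A minimal standalone sketch of that pattern, assuming the splitter settings shown in this diff and re-deriving the hash the same way as the generate_text_hash helper added later in the commit (all names below are illustrative):

import hashlib
import uuid

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter


def split_and_tag(text: str) -> list:
    """Split raw text into chunks and attach doc_id/doc_hash metadata (sketch)."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "。", ".", " ", ""],
    )
    documents = splitter.split_documents([Document(page_content=text)])

    tagged = []
    for document in documents:
        # drop empty chunks, mirroring the check in _split_to_documents
        if not document.page_content or not document.page_content.strip():
            continue
        document.metadata['doc_id'] = str(uuid.uuid4())
        # same scheme as helper.generate_text_hash (sha256 over text + 'None')
        document.metadata['doc_hash'] = hashlib.sha256(
            (document.page_content + 'None').encode()
        ).hexdigest()
        tagged.append(document)
    return tagged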
def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
""" """
...@@ -506,37 +484,38 @@ class IndexingRunner: ...@@ -506,37 +484,38 @@ class IndexingRunner:
return text return text
def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None: def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
""" """
Build the index for the document. Build the index for the document.
""" """
vector_index = VectorIndex(dataset=dataset) vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = KeywordTableIndex(dataset=dataset) keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
# chunk nodes by chunk size # chunk nodes by chunk size
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
chunk_size = 100 chunk_size = 100
for i in range(0, len(nodes), chunk_size): for i in range(0, len(documents), chunk_size):
# check document is paused # check document is paused
self._check_document_paused_status(document.id) self._check_document_paused_status(dataset_document.id)
chunk_nodes = nodes[i:i + chunk_size] chunk_documents = documents[i:i + chunk_size]
tokens += sum( tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
for document in chunk_documents
) )
# save vector index # save vector index
if dataset.indexing_technique == "high_quality": if vector_index:
vector_index.add_nodes(chunk_nodes) vector_index.add_texts(chunk_documents)
# save keyword index # save keyword index
keyword_table_index.add_nodes(chunk_nodes) keyword_table_index.add_texts(chunk_documents)
node_ids = [node.doc_id for node in chunk_nodes] document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.index_node_id.in_(node_ids), DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing" DocumentSegment.status == "indexing"
).update({ ).update({
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
...@@ -549,12 +528,12 @@ class IndexingRunner: ...@@ -549,12 +528,12 @@ class IndexingRunner:
# update document status to completed # update document status to completed
self._update_document_index_status( self._update_document_index_status(
document_id=document.id, document_id=dataset_document.id,
after_indexing_status="completed", after_indexing_status="completed",
extra_update_params={ extra_update_params={
Document.tokens: tokens, DatasetDocument.tokens: tokens,
Document.completed_at: datetime.datetime.utcnow(), DatasetDocument.completed_at: datetime.datetime.utcnow(),
Document.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
} }
) )
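Editor's note: the reworked _build_index batches the split documents in groups of 100, feeds each batch to the vector index (when present) and the keyword index, and accumulates a token count before flipping the matching segments to "completed". A rough sketch of that batching loop with a stand-in add_texts callable and a naive token estimate in place of TokenCalculator (both are assumptions, not the project's real helpers):

from typing import Callable, List

from langchain.schema import Document


def index_in_chunks(documents: List[Document],
                    add_texts: Callable[[List[Document]], None],
                    chunk_size: int = 100) -> int:
    """Add documents to an index in fixed-size batches; returns a rough token total."""
    tokens = 0
    for i in range(0, len(documents), chunk_size):
        chunk = documents[i:i + chunk_size]
        # naive whitespace estimate standing in for TokenCalculator.get_num_tokens
        tokens += sum(len(doc.page_content.split()) for doc in chunk)
        add_texts(chunk)  # e.g. vector_index.add_texts(chunk) in the runner
    return tokens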
...@@ -569,25 +548,25 @@ class IndexingRunner: ...@@ -569,25 +548,25 @@ class IndexingRunner:
""" """
Update the document indexing status. Update the document indexing status.
""" """
count = Document.query.filter_by(id=document_id, is_paused=True).count() count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0: if count > 0:
raise DocumentIsPausedException() raise DocumentIsPausedException()
update_params = { update_params = {
Document.indexing_status: after_indexing_status DatasetDocument.indexing_status: after_indexing_status
} }
if extra_update_params: if extra_update_params:
update_params.update(extra_update_params) update_params.update(extra_update_params)
Document.query.filter_by(id=document_id).update(update_params) DatasetDocument.query.filter_by(id=document_id).update(update_params)
db.session.commit() db.session.commit()
def _update_segments_by_document(self, document_id: str, update_params: dict) -> None: def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
""" """
Update the document segment by document id. Update the document segment by document id.
""" """
DocumentSegment.query.filter_by(document_id=document_id).update(update_params) DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit() db.session.commit()
......
from typing import Union, Optional from typing import Union, Optional, List
from langchain.callbacks import CallbackManager from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.fake import FakeListLLM
from core.constant import llm_constant from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError from core.llm.error import ProviderTokenNotInitError
...@@ -32,12 +31,11 @@ class LLMBuilder: ...@@ -32,12 +31,11 @@ class LLMBuilder:
""" """
@classmethod @classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]: def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
if model_name == 'fake':
return FakeListLLM(responses=[])
provider = cls.get_default_provider(tenant_id) provider = cls.get_default_provider(tenant_id)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
mode = cls.get_mode_by_model(model_name) mode = cls.get_mode_by_model(model_name)
if mode == 'chat': if mode == 'chat':
if provider == 'openai': if provider == 'openai':
...@@ -52,16 +50,21 @@ class LLMBuilder: ...@@ -52,16 +50,21 @@ class LLMBuilder:
else: else:
raise ValueError(f"model name {model_name} is not supported.") raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
model_kwargs = {
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
return llm_cls( return llm_cls(
model_name=model_name, model_name=model_name,
temperature=kwargs.get('temperature', 0), temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256), max_tokens=kwargs.get('max_tokens', 256),
top_p=kwargs.get('top_p', 1), **model_extras_kwargs,
frequency_penalty=kwargs.get('frequency_penalty', 0), callbacks=kwargs.get('callbacks', None),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
streaming=kwargs.get('streaming', False), streaming=kwargs.get('streaming', False),
# request_timeout=None # request_timeout=None
**model_credentials **model_credentials
...@@ -69,7 +72,7 @@ class LLMBuilder: ...@@ -69,7 +72,7 @@ class LLMBuilder:
@classmethod @classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name") model_name = model.get("name")
completion_params = model.get("completion_params", {}) completion_params = model.get("completion_params", {})
...@@ -82,7 +85,7 @@ class LLMBuilder: ...@@ -82,7 +85,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1), frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1), presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming, streaming=streaming,
callback_manager=callback_manager callbacks=callbacks
) )
@classmethod @classmethod
......
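Editor's note: the LLMBuilder change reflects a langchain convention — completion models (OpenAI/AzureOpenAI) take top_p, frequency_penalty and presence_penalty as first-class constructor fields, while chat models expect extra sampling parameters inside a model_kwargs dict, and both now accept a plain list of callback handlers rather than a CallbackManager. A hedged sketch of that dispatch (the API key is a placeholder; no request is made at construction time):

from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI


def build_llm(mode: str, model_name: str, streaming: bool = False):
    """Return a chat or completion model configured the way the builder does (sketch)."""
    sampling = {'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}
    common = dict(
        model_name=model_name,
        temperature=0,
        max_tokens=256,
        streaming=streaming,
        openai_api_key='sk-placeholder',  # placeholder credential
    )
    if mode == 'chat':
        # chat models tuck extra sampling params into model_kwargs
        return ChatOpenAI(model_kwargs=sampling, **common)
    # completion models accept them directly
    return OpenAI(**sampling, **common)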
...@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider): ...@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
""" """
config = self.get_provider_api_key(model_id=model_id) config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure' config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id.replace('.', '') if model_id else None if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config return config
def get_provider_name(self): def get_provider_name(self):
......
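Editor's note: the Azure branch matters because, in this langchain version, OpenAIEmbeddings takes the Azure deployment via a deployment field while the (Azure)OpenAI LLM classes use deployment_name, which is presumably why the provider special-cases the embedding model id. A small sketch of the resulting config shape (values are placeholders):

def azure_model_config(model_id: str, api_key: str) -> dict:
    """Return kwargs shaped for the matching langchain class (assumed shape)."""
    config = {
        'openai_api_type': 'azure',
        'openai_api_key': api_key,
    }
    deployment = model_id.replace('.', '') if model_id else None
    if model_id == 'text-embedding-ada-002':
        config['deployment'] = deployment        # OpenAIEmbeddings(deployment=...)
    else:
        config['deployment_name'] = deployment   # AzureOpenAI(deployment_name=...)
    return config


print(azure_model_config('gpt-3.5-turbo', 'azure-key-placeholder'))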
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
...@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): ...@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
return message_tokens return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions @handle_llm_exceptions
def generate( def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return super().generate(messages, stop) return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @handle_llm_exceptions_async
async def agenerate( async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return await super().agenerate(messages, stop) return await super().agenerate(messages, stop, callbacks, **kwargs)
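Editor's note: after the upgrade the streamable wrappers no longer re-implement _generate; they simply forward callbacks and extra kwargs to the parent generate/agenerate, leaving error translation to the handle_llm_exceptions decorators. The decorator body is not shown in this diff; a hypothetical sketch of what such a wrapper could look like (ProviderError is an invented name):

import functools

import openai


class ProviderError(Exception):
    """Hypothetical normalized error type."""


def handle_llm_exceptions(func):
    """Translate raw OpenAI errors into a provider-level exception (sketch)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except openai.error.OpenAIError as e:
            raise ProviderError(str(e)) from e
    return wrapper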
import os from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any from typing import Optional, List, Dict, Mapping, Any
...@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI): ...@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
@handle_llm_exceptions @handle_llm_exceptions
def generate( def generate(
self, prompts: List[str], stop: Optional[List[str]] = None self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return super().generate(prompts, stop) return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @handle_llm_exceptions_async
async def agenerate( async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return await super().agenerate(prompts, stop) return await super().agenerate(prompts, stop, callbacks, **kwargs)
import os import os
from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
...@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI): ...@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
return message_tokens return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions @handle_llm_exceptions
def generate( def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return super().generate(messages, stop) return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @handle_llm_exceptions_async
async def agenerate( async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return await super().agenerate(messages, stop) return await super().agenerate(messages, stop, callbacks, **kwargs)
import os import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI from langchain import OpenAI
...@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI): ...@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
}} }}
@handle_llm_exceptions @handle_llm_exceptions
def generate( def generate(
self, prompts: List[str], stop: Optional[List[str]] = None self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return super().generate(prompts, stop) return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @handle_llm_exceptions_async
async def agenerate( async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return await super().agenerate(prompts, stop) return await super().agenerate(prompts, stop, callbacks, **kwargs)
from typing import Any, List, Dict from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel from langchain.schema import get_buffer_string
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory ReadOnlyConversationTokenDBBufferSharedMemory
......
from llama_index import QueryKeywordExtractPrompt
CONVERSATION_TITLE_PROMPT = ( CONVERSATION_TITLE_PROMPT = (
"Human:{query}\n-----\n" "Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n" "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
...@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( ...@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"[\"question1\",\"question2\",\"question3\"]\n" "[\"question1\",\"question2\",\"question3\"]\n"
) )
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
"A question is provided below. Given the question, extract up to {max_keywords} "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question. Avoid stopwords."
"I am not sure which language the following question is in. "
"If the user asked the question in Chinese, please return the keywords in Chinese. "
"If the user asked the question in English, please return the keywords in English.\n"
"---------------------\n"
"{question}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input. the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement. You will be provided with the prompt, variables, and an opening statement.
......
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class DatasetTool(BaseTool):
"""Tool for querying a Dataset."""
dataset: Dataset
k: int = 2
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
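Editor's note: DatasetTool plugs into agents through the standard langchain BaseTool contract — implement _run (and optionally _arun) and the framework invokes it via tool.run(...). A self-contained toy analogue of that contract, with no database or vector store involved (purely illustrative, not the project's tool):

from langchain.tools import BaseTool


class EchoDatasetTool(BaseTool):
    """Toy stand-in for DatasetTool: returns canned 'documents' for a query."""
    name = "dataset-demo"
    description = "useful for when you want to answer queries about the demo dataset"
    k: int = 2

    def _run(self, tool_input: str) -> str:
        # a real tool would hit the keyword table or vector index here
        documents = [f"chunk {i} matching: {tool_input}" for i in range(self.k)]
        return "\n".join(documents)

    async def _arun(self, tool_input: str) -> str:
        return self._run(tool_input)


print(EchoDatasetTool().run("what is dify?"))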
from typing import Optional
from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
class DatasetToolBuilder:
@classmethod
def build_dataset_tool(cls, dataset: Dataset,
response_mode: str = "no_synthesizer",
callback_handler: Optional[DatasetToolCallbackHandler] = None):
if dataset.indexing_technique == "economy":
# use keyword table query
index = KeywordTableIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
"query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
"max_keywords_per_query": 5,
# If num_chunks_per_query is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"num_chunks_per_query": 2
}
else:
index = VectorIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
# If top_k is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"similarity_top_k": 2
}
# fulfill description when it is empty
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
index_tool_config = IndexToolConfig(
index=index,
name=f"dataset-{dataset.id}",
description=description,
index_query_kwargs=query_kwargs,
tool_kwargs={
"callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
},
# tool_kwargs={"return_direct": True},
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
)
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
return EnhanceLlamaIndexTool.from_tool_config(
tool_config=index_tool_config,
callback_handler=index_callback_handler
)
from typing import Dict
from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
class EnhanceLlamaIndexTool(BaseTool):
"""Tool for querying a LlamaIndex."""
# NOTE: name/description still needs to be set
index: BaseGPTIndex
query_kwargs: Dict = Field(default_factory=dict)
return_sources: bool = False
callback_handler: IndexToolCallbackHandler
@classmethod
def from_tool_config(cls, tool_config: IndexToolConfig,
callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
"""Create a tool from a tool config."""
return_sources = tool_config.tool_kwargs.pop("return_sources", False)
return cls(
index=tool_config.index,
callback_handler=callback_handler,
name=tool_config.name,
description=tool_config.description,
return_sources=return_sources,
query_kwargs=tool_config.index_query_kwargs,
**tool_config.tool_kwargs,
)
def _run(self, tool_input: str) -> str:
response = self.index.query(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)
async def _arun(self, tool_input: str) -> str:
response = await self.index.aquery(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)
from abc import ABC, abstractmethod
from typing import Optional
from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore
class BaseVectorStoreClient(ABC):
@abstractmethod
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
raise NotImplementedError
@abstractmethod
def to_index_config(self, index_id: str) -> dict:
raise NotImplementedError
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
def delete_node(self, node_id: str):
self._vector_store.delete_node(node_id)
def exists_by_node_id(self, node_id: str) -> bool:
return self._vector_store.exists_by_node_id(node_id)
class EnhanceVectorStore(ABC):
@abstractmethod
def delete_node(self, node_id: str):
pass
@abstractmethod
def exists_by_node_id(self, node_id: str) -> bool:
pass
from typing import cast, Any
from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
class QdrantVectorStore(Qdrant):
def del_texts(self, filter: Filter):
if not filter:
raise ValueError('filter must not be empty')
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def del_text(self, uuid: str) -> None:
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=PointIdsList(
points=[uuid],
),
)
def text_exists(self, uuid: str) -> bool:
self._reload_if_needed()
response = self.client.retrieve(
collection_name=self.collection_name,
ids=[uuid]
)
return len(response) > 0
def delete(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
if scored_point.payload.get('doc_id'):
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata={'doc_id': scored_point.id}
)
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
def _reload_if_needed(self):
if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client)
self.client._load()
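Editor's note: the new QdrantVectorStore adds filter-based deletion on top of langchain's Qdrant wrapper, and the filters are ordinary qdrant_client models. A sketch of building one that matches points by a metadata field — the "metadata.document_id" payload key is an assumption about how the index stores document metadata, not something shown in this diff:

from qdrant_client.http.models import FieldCondition, Filter, MatchValue


def document_filter(document_id: str) -> Filter:
    """Filter matching every point whose payload carries this document id (assumed key)."""
    return Filter(
        must=[
            FieldCondition(
                key="metadata.document_id",
                match=MatchValue(value=document_id),
            )
        ]
    )

# e.g. vector_store.del_texts(document_filter("doc-uuid"))  # hypothetical usage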
import os
from typing import cast, List
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter
import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class QdrantVectorStoreClient(BaseVectorStoreClient):
def __init__(self, url: str, api_key: str, root_path: str):
self._client = self.init_from_config(url, api_key, root_path)
@classmethod
def init_from_config(cls, url: str, api_key: str, root_path: str):
if url and url.startswith('path:'):
path = url.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(root_path, path)
return qdrant_client.QdrantClient(
path=path
)
else:
return qdrant_client.QdrantClient(
url=url,
api_key=api_key,
)
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = QdrantIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"collection_name": "Gpt_index_xxx"}
collection_name = config.get('collection_name')
if not collection_name:
raise Exception("collection_name cannot be None.")
return GPTQdrantEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=QdrantEnhanceVectorStore(
client=self._client,
collection_name=collection_name
)
)
def to_index_config(self, index_id: str) -> dict:
return {"collection_name": index_id}
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
pass
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
from qdrant_client.http import models as rest
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=rest.Filter(
must=[
rest.FieldCondition(
key="id", match=rest.MatchValue(value=node_id)
)
]
),
)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
self._reload_if_needed()
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[node_id]
)
return len(response) > 0
def query(
self,
query: VectorStoreQuery,
) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query (VectorStoreQuery): query
"""
query_embedding = cast(List[float], query.query_embedding)
self._reload_if_needed()
response = self._client.search(
collection_name=self._collection_name,
query_vector=query_embedding,
limit=cast(int, query.similarity_top_k),
query_filter=cast(Filter, self._build_query_filter(query)),
with_vectors=True
)
nodes = []
similarities = []
ids = []
for point in response:
payload = cast(Payload, point.payload)
node = Node(
doc_id=str(point.id),
text=payload.get("text"),
embedding=point.vector,
extra_info=payload.get("extra_info"),
relationships={
DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
},
)
nodes.append(node)
similarities.append(point.score)
ids.append(str(point.id))
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
def _reload_if_needed(self):
if isinstance(self._client._client, QdrantLocal):
self._client._client._load()
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
class VectorStore:
def __init__(self):
self._vector_store = None
self._client = None
def init_app(self, app: Flask):
if not app.config['VECTOR_STORE']:
return
self._vector_store = app.config['VECTOR_STORE']
if self._vector_store not in SUPPORTED_VECTOR_STORES:
raise ValueError(f"Vector store {self._vector_store} is not supported.")
if self._vector_store == 'weaviate':
self._client = WeaviateVectorStoreClient(
endpoint=app.config['WEAVIATE_ENDPOINT'],
api_key=app.config['WEAVIATE_API_KEY'],
grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
batch_size=app.config['WEAVIATE_BATCH_SIZE']
)
elif self._vector_store == 'qdrant':
self._client = QdrantVectorStoreClient(
url=app.config['QDRANT_URL'],
api_key=app.config['QDRANT_API_KEY'],
root_path=app.root_path
)
app.extensions['vector_store'] = self
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
vector_store_config: dict = index_struct.get('vector_store')
index = self.get_client().get_index(
service_context=service_context,
config=vector_store_config
)
return index
def to_index_struct(self, index_id: str) -> dict:
return {
"type": self._vector_store,
"vector_store": self.get_client().to_index_config(index_id)
}
def get_client(self):
if not self._client:
raise Exception("Vector store client is not initialized.")
return self._client
from llama_index.indices.query.base import IS
from typing import (
Any,
Dict,
List,
Optional
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
from langchain.vectorstores import Weaviate
class WeaviateVectorStore(Weaviate):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects(
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None:
self._client.data_object.delete(
uuid,
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": uuid,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name]
if len(entries) == 0:
return False
return True
def delete(self):
self._client.schema.delete_class(self._index_name)
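Editor's note: WeaviateVectorStore.del_texts takes a plain Weaviate where-filter dict, the same shape used above in text_exists. A short sketch of such a filter keyed on a metadata property (the "document_id" property name is an assumption):

def weaviate_where(document_id: str) -> dict:
    """Build a Weaviate where-filter for one document (property name assumed)."""
    return {
        "path": ["document_id"],
        "operator": "Equal",
        "valueText": document_id,
    }

# e.g. vector_store.del_texts(weaviate_where("doc-uuid"))  # hypothetical usage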
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
parse_get_response,
validate_client,
)
class WeaviateVectorStoreClient(BaseVectorStoreClient):
def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
weaviate.connect.connection.has_grpc = grpc_enabled
client = weaviate.Client(
url=endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = WeaviateIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"class_prefix": "Gpt_index_xxx"}
class_prefix = config.get('class_prefix')
if not class_prefix:
raise Exception("class_prefix cannot be None.")
return GPTWeaviateEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=WeaviateWithSimilaritiesVectorStore(
weaviate_client=self._client,
class_prefix=class_prefix
)
)
def to_index_config(self, index_id: str) -> dict:
return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
nodes = self.weaviate_query(
self._client,
self._class_prefix,
query,
)
nodes = nodes[: query.similarity_top_k]
node_idxs = [str(i) for i in range(len(nodes))]
similarities = []
for node in nodes:
similarities.append(node.extra_info['similarity'])
del node.extra_info['similarity']
return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
def weaviate_query(
self,
client: Any,
class_prefix: str,
query_spec: VectorStoreQuery,
) -> List[Node]:
"""Convert to LlamaIndex list."""
validate_client(client)
class_name = _class_name(class_prefix)
prop_names = [p["name"] for p in NODE_SCHEMA]
vector = query_spec.query_embedding
# build query
query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
if query_spec.mode == VectorStoreQueryMode.DEFAULT:
_logger.debug("Using vector search")
if vector is not None:
query = query.with_near_vector(
{
"vector": vector,
}
)
elif query_spec.mode == VectorStoreQueryMode.HYBRID:
_logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
query = query.with_hybrid(
query=query_spec.query_str,
alpha=query_spec.alpha,
vector=vector,
)
query = query.with_limit(query_spec.similarity_top_k)
_logger.debug(f"Using limit of {query_spec.similarity_top_k}")
# execute query
query_result = query.do()
# parse results
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
results = [self._to_node(entry) for entry in entries]
return results
def _to_node(self, entry: Dict) -> Node:
"""Convert to Node."""
extra_info_str = entry["extra_info"]
if extra_info_str == "":
extra_info = None
else:
extra_info = json.loads(extra_info_str)
if 'certainty' in entry['_additional']:
if extra_info:
extra_info['similarity'] = entry['_additional']['certainty']
else:
extra_info = {'similarity': entry['_additional']['certainty']}
node_info_str = entry["node_info"]
if node_info_str == "":
node_info = None
else:
node_info = json.loads(node_info_str)
relationships_str = entry["relationships"]
relationships: Dict[DocumentRelationship, str]
if relationships_str == "":
relationships = field(default_factory=dict)
else:
relationships = {
DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
}
return Node(
text=entry["text"],
doc_id=entry["doc_id"],
embedding=entry["_additional"]["vector"],
extra_info=extra_info,
node_info=node_info,
relationships=relationships,
)
def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document.
Args:
doc_id (str): document id
"""
delete_document(self._client, doc_id, self._class_prefix)
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
delete_node(self._client, node_id, self._class_prefix)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
entry = get_by_node_id(self._client, node_id, self._class_prefix)
return True if entry else False
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
pass
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["ref_doc_id"],
"operator": "Equal",
"valueString": ref_doc_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
while len(entries) > 0:
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
if len(entries) == 0:
return None
return entries[0]
from core.vector_store.vector_store import VectorStore
vector_store = VectorStore()
def init_app(app):
vector_store.init_app(app)
...@@ -3,6 +3,7 @@ import re ...@@ -3,6 +3,7 @@ import re
import subprocess import subprocess
import uuid import uuid
from datetime import datetime from datetime import datetime
from hashlib import sha256
from zoneinfo import available_timezones from zoneinfo import available_timezones
import random import random
import string import string
...@@ -147,3 +148,8 @@ def get_remote_ip(request): ...@@ -147,3 +148,8 @@ def get_remote_ip(request):
return request.headers.getlist("X-Forwarded-For")[0] return request.headers.getlist("X-Forwarded-For")[0]
else: else:
return request.remote_addr return request.remote_addr
def generate_text_hash(text: str) -> str:
hash_text = str(text) + 'None'
return sha256(hash_text.encode()).hexdigest()
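Editor's note: the new generate_text_hash helper gives each chunk a content-addressable id (sha256 over the text plus the literal string 'None'), presumably so identical content can be detected without comparing full text. A small worked example of using it for de-duplication (the helper body is copied from this diff; the loop around it is illustrative):

from hashlib import sha256


def generate_text_hash(text: str) -> str:
    hash_text = str(text) + 'None'
    return sha256(hash_text.encode()).hexdigest()


chunks = ["hello world", "hello world", "goodbye"]
seen = set()
unique_chunks = []
for chunk in chunks:
    doc_hash = generate_text_hash(chunk)
    if doc_hash in seen:
        continue  # identical content already hashed, skip it
    seen.add(doc_hash)
    unique_chunks.append(chunk)

print(unique_chunks)  # ['hello world', 'goodbye']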
...@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model): ...@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
_current_tenant: db.Model = None
@property @property
def current_tenant(self): def current_tenant(self):
return self._current_tenant return self._current_tenant
......
...@@ -66,6 +66,23 @@ class Dataset(db.Model): ...@@ -66,6 +66,23 @@ class Dataset(db.Model):
def document_count(self): def document_count(self):
return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
@property
def available_document_count(self):
return db.session.query(func.count(Document.id)).filter(
Document.dataset_id == self.id,
Document.indexing_status == 'completed',
Document.enabled == True,
Document.archived == False
).scalar()
@property
def available_segment_count(self):
return db.session.query(func.count(DocumentSegment.id)).filter(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).scalar()
@property @property
def word_count(self): def word_count(self):
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
...@@ -260,7 +277,7 @@ class Document(db.Model): ...@@ -260,7 +277,7 @@ class Document(db.Model):
@property @property
def dataset(self): def dataset(self):
return Dataset.query.get(self.dataset_id) return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
@property @property
def segment_count(self): def segment_count(self):
...@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model): ...@@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model):
@property @property
def keyword_table_dict(self): def keyword_table_dict(self):
return json.loads(self.keyword_table) if self.keyword_table else None class SetDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, dct):
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
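Editor's note: keyword_table_dict now decodes the stored JSON back into keyword -> set(segment ids), presumably so the keyword index can do fast membership and set operations. Since sets are not JSON-serializable, the writer side has to turn them into lists first; a sketch of the full round trip, where the encoder is an assumption about how the table is persisted:

import json


class SetEncoder(json.JSONEncoder):
    """Assumed counterpart: serialize sets as lists."""
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)


class SetDecoder(json.JSONDecoder):
    """Mirror of the decoder in keyword_table_dict: lists back to sets."""
    def __init__(self, *args, **kwargs):
        super().__init__(object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, dct):
        for keyword, node_idxs in dct.items():
            if isinstance(node_idxs, list):
                dct[keyword] = set(node_idxs)
        return dct


table = {"pricing": {"seg-1", "seg-2"}, "quota": {"seg-3"}}
stored = json.dumps(table, cls=SetEncoder)
assert json.loads(stored, cls=SetDecoder) == table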
class Embedding(db.Model): class Embedding(db.Model):
......
...@@ -2,6 +2,7 @@ coverage~=7.2.4 ...@@ -2,6 +2,7 @@ coverage~=7.2.4
beautifulsoup4==4.12.2 beautifulsoup4==4.12.2
flask~=2.3.2 flask~=2.3.2
Flask-SQLAlchemy~=3.0.3 Flask-SQLAlchemy~=3.0.3
SQLAlchemy~=1.4.28
flask-login==0.6.2 flask-login==0.6.2
flask-migrate~=4.0.4 flask-migrate~=4.0.4
flask-restful==0.3.9 flask-restful==0.3.9
...@@ -9,8 +10,7 @@ flask-session2==1.3.1 ...@@ -9,8 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10 flask-cors==3.0.10
gunicorn~=20.1.0 gunicorn~=20.1.0
gevent~=22.10.2 gevent~=22.10.2
langchain==0.0.142 langchain==0.0.209
llama-index==0.5.27
openai~=0.27.5 openai~=0.27.5
psycopg2-binary~=2.9.6 psycopg2-binary~=2.9.6
pycryptodome==3.17 pycryptodome==3.17
...@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1 ...@@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1
jieba==0.42.1 jieba==0.42.1
celery==5.2.7 celery==5.2.7
redis~=4.5.4 redis~=4.5.4
pypdf==3.8.1
openpyxl==3.1.2 openpyxl==3.1.2
chardet~=5.1.0 chardet~=5.1.0
\ No newline at end of file docx2txt==0.8
pypdfium2==4.16.0
\ No newline at end of file
...@@ -4,7 +4,6 @@ import uuid ...@@ -4,7 +4,6 @@ import uuid
from core.constant import llm_constant from core.constant import llm_constant
from models.account import Account from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class AppModelConfigService: class AppModelConfigService:
......
...@@ -7,7 +7,6 @@ from typing import Optional, List ...@@ -7,7 +7,6 @@ from typing import Optional, List
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from flask_login import current_user from flask_login import current_user
from core.index.index_builder import IndexBuilder
from events.dataset_event import dataset_was_deleted from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted from events.document_event import document_was_deleted
from extensions.ext_database import db from extensions.ext_database import db
...@@ -386,8 +385,6 @@ class DocumentService: ...@@ -386,8 +385,6 @@ class DocumentService:
dataset.indexing_technique = document_data["indexing_technique"] dataset.indexing_technique = document_data["indexing_technique"]
if dataset.indexing_technique == 'high_quality':
IndexBuilder.get_default_service_context(dataset.tenant_id)
documents = [] documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if 'original_document_id' in document_data and document_data["original_document_id"]: if 'original_document_id' in document_data and document_data["original_document_id"]:
......
...@@ -3,47 +3,56 @@ import time ...@@ -3,47 +3,56 @@ import time
from typing import List from typing import List
import numpy as np import numpy as np
from llama_index.data_structs.node_v2 import NodeWithScore from flask import current_app
from llama_index.indices.query.schema import QueryBundle from langchain.embeddings import OpenAIEmbeddings
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from sklearn.manifold import TSNE from sklearn.manifold import TSNE
from core.docstore.empty_docstore import EmptyDocumentStore from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.errors.index import IndexNotInitializedError
class HitTestingService: class HitTestingService:
@classmethod @classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
index = VectorIndex(dataset=dataset).query_index if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
if not index: "query": {
raise IndexNotInitializedError() "content": query,
"tsne_position": {'x': 0, 'y': 0},
index_query = GPTVectorStoreIndexQuery( },
index_struct=index.index_struct, "records": []
service_context=index.service_context, }
vector_store=index.query_context.get('vector_store'),
docstore=EmptyDocumentStore(),
response_synthesizer=None,
similarity_top_k=limit
)
query_bundle = QueryBundle( model_credentials = LLMBuilder.get_model_credentials(
query_str=query, tenant_id=dataset.tenant_id,
custom_embedding_strs=[query], model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
) )
query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries( embeddings = CacheEmbedding(OpenAIEmbeddings(
query_bundle.embedding_strs **model_credentials
))
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
) )
start = time.perf_counter() start = time.perf_counter()
nodes = index_query.retrieve(query_bundle=query_bundle) documents = vector_index.search(
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10
}
)
end = time.perf_counter() end = time.perf_counter()
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
...@@ -58,25 +67,24 @@ class HitTestingService: ...@@ -58,25 +67,24 @@ class HitTestingService:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
return cls.compact_retrieve_response(dataset, query_bundle, nodes) return cls.compact_retrieve_response(dataset, embeddings, query, documents)
@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]): def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
embeddings = [ text_embeddings = [
query_bundle.embedding embeddings.embed_query(query)
] ]
for node in nodes: text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))
embeddings.append(node.node.embedding)
tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
query_position = tsne_position_data.pop(0) query_position = tsne_position_data.pop(0)
i = 0 i = 0
records = [] records = []
for node in nodes: for document in documents:
index_node_id = node.node.doc_id index_node_id = document.metadata['doc_id']
segment = db.session.query(DocumentSegment).filter( segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
...@@ -91,7 +99,7 @@ class HitTestingService: ...@@ -91,7 +99,7 @@ class HitTestingService:
record = { record = {
"segment": segment, "segment": segment,
"score": node.score, "score": document.metadata['score'],
"tsne_position": tsne_position_data[i] "tsne_position": tsne_position_data[i]
} }
...@@ -101,7 +109,7 @@ class HitTestingService: ...@@ -101,7 +109,7 @@ class HitTestingService:
return { return {
"query": { "query": {
"content": query_bundle.query_str, "content": query,
"tsne_position": query_position, "tsne_position": query_position,
}, },
"records": records "records": records
......
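Taken together, the hunk above replaces the llama_index query path with a langchain-based one: resolve tenant credentials, wrap the embedding model in the cache, build the vector index, search. A minimal sketch of the resulting flow, assuming the constructor and method signatures visible in this diff; `retrieve_segments` is an illustrative helper name, not part of the patch:

```python
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder


def retrieve_segments(dataset, query: str, limit: int = 10):
    # Resolve tenant-scoped credentials for the embedding model.
    model_credentials = LLMBuilder.get_model_credentials(
        tenant_id=dataset.tenant_id,
        model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
        model_name='text-embedding-ada-002'
    )

    # Wrap OpenAI embeddings in the project's cache layer so repeated
    # queries reuse previously computed vectors.
    embeddings = CacheEmbedding(OpenAIEmbeddings(**model_credentials))

    # Build the langchain-backed vector index and run a threshold-filtered
    # similarity search (the service above hard-codes k=10).
    vector_index = VectorIndex(
        dataset=dataset,
        config=current_app.config,
        embeddings=embeddings
    )
    return vector_index.search(
        query,
        search_type='similarity_score_threshold',
        search_kwargs={'k': limit}
    )
```

The returned langchain `Document` objects carry `doc_id` and `score` in `metadata`, which is what `compact_retrieve_response` reads back when it builds the hit records and t-SNE positions.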
...@@ -4,96 +4,81 @@ import time ...@@ -4,96 +4,81 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from llama_index.data_structs import Node from langchain.schema import Document
from llama_index.data_structs.node_v2 import DocumentRelationship
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment, Document from models.dataset import DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task @shared_task
def add_document_to_index_task(document_id: str): def add_document_to_index_task(dataset_document_id: str):
""" """
Async Add document to index Async Add document to index
:param document_id: :param document_id:
Usage: add_document_to_index.delay(document_id) Usage: add_document_to_index.delay(document_id)
""" """
logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green')) logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green'))
start_at = time.perf_counter() start_at = time.perf_counter()
document = db.session.query(Document).filter(Document.id == document_id).first() dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
if not document: if not dataset_document:
raise NotFound('Document not found') raise NotFound('Document not found')
if document.indexing_status != 'completed': if dataset_document.indexing_status != 'completed':
return return
indexing_cache_key = 'document_{}_indexing'.format(document.id) indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id)
try: try:
segments = db.session.query(DocumentSegment).filter( segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True DocumentSegment.enabled == True
) \ ) \
.order_by(DocumentSegment.position.asc()).all() .order_by(DocumentSegment.position.asc()).all()
nodes = [] documents = []
previous_node = None
for segment in segments: for segment in segments:
relationships = { document = Document(
DocumentRelationship.SOURCE: document.id page_content=segment.content,
} metadata={
"doc_id": segment.index_node_id,
if previous_node: "doc_hash": segment.index_node_hash,
relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id "document_id": segment.document_id,
"dataset_id": segment.dataset_id,
previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id }
node = Node(
doc_id=segment.index_node_id,
doc_hash=segment.index_node_hash,
text=segment.content,
extra_info=None,
node_info=None,
relationships=relationships
) )
previous_node = node documents.append(document)
nodes.append(node) dataset = dataset_document.dataset
dataset = document.dataset
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
# save vector index # save vector index
if dataset.indexing_technique == "high_quality": index = IndexBuilder.get_index(dataset, 'high_quality')
vector_index.add_nodes( if index:
nodes=nodes, index.add_texts(documents)
duplicate_check=True
)
# save keyword index # save keyword index
keyword_table_index.add_nodes(nodes) index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green'))
except Exception as e: except Exception as e:
logging.exception("add document to index failed") logging.exception("add document to index failed")
document.enabled = False dataset_document.enabled = False
document.disabled_at = datetime.datetime.utcnow() dataset_document.disabled_at = datetime.datetime.utcnow()
document.status = 'error' dataset_document.status = 'error'
document.error = str(e) dataset_document.error = str(e)
db.session.commit() db.session.commit()
finally: finally:
redis_client.delete(indexing_cache_key) redis_client.delete(indexing_cache_key)
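The segment-to-`Document` mapping used above reappears in several tasks in this change. A minimal sketch of the shared pattern, assuming the `IndexBuilder`/`add_texts` interfaces shown in the hunk; `index_segments` is a hypothetical helper, not part of the patch:

```python
from langchain.schema import Document

from core.index.index import IndexBuilder


def index_segments(dataset, segments):
    # Each DocumentSegment becomes a langchain Document whose metadata keeps
    # the identifiers the indexes need to locate and deduplicate content.
    documents = [
        Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            }
        )
        for segment in segments
    ]

    # get_index returns None when the dataset does not use the requested
    # indexing technique, so both additions are guarded.
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    if vector_index:
        vector_index.add_texts(documents)

    kw_index = IndexBuilder.get_index(dataset, 'economy')
    if kw_index:
        kw_index.add_texts(documents)
```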
...@@ -4,12 +4,10 @@ import time ...@@ -4,12 +4,10 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from llama_index.data_structs import Node from langchain.schema import Document
from llama_index.data_structs.node_v2 import DocumentRelationship
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
...@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str): ...@@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str):
indexing_cache_key = 'segment_{}_indexing'.format(segment.id) indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
try: try:
relationships = { document = Document(
DocumentRelationship.SOURCE: segment.document_id, page_content=segment.content,
} metadata={
"doc_id": segment.index_node_id,
previous_segment = segment.previous_segment "doc_hash": segment.index_node_hash,
if previous_segment: "document_id": segment.document_id,
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id "dataset_id": segment.dataset_id,
}
next_segment = segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=segment.index_node_id,
doc_hash=segment.index_node_hash,
text=segment.content,
extra_info=None,
node_info=None,
relationships=relationships
) )
dataset = segment.dataset dataset = segment.dataset
if not dataset: if not dataset:
raise Exception('Segment has no dataset') logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
return
vector_index = VectorIndex(dataset=dataset) dataset_document = segment.document
keyword_table_index = KeywordTableIndex(dataset=dataset)
if not dataset_document:
logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return
# save vector index # save vector index
if dataset.indexing_technique == "high_quality": index = IndexBuilder.get_index(dataset, 'high_quality')
vector_index.add_nodes( if index:
nodes=[node], index.add_texts([document], duplicate_check=True)
duplicate_check=True
)
# save keyword index # save keyword index
keyword_table_index.add_nodes([node]) index = IndexBuilder.get_index(dataset, 'economy')
if index:
index.add_texts([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
......
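Both segment tasks in this change now skip work instead of raising when the surrounding dataset or document is unusable. A condensed sketch of those guards, assuming the model attributes referenced in the hunk; `segment_is_indexable` is an illustrative name only:

```python
def segment_is_indexable(segment) -> bool:
    # Mirrors the guards added above: a segment is only (re)indexed when it
    # still has a dataset and an enabled, non-archived, fully indexed document.
    dataset = segment.dataset
    if not dataset:
        return False

    dataset_document = segment.document
    if not dataset_document:
        return False

    return (
        dataset_document.enabled
        and not dataset_document.archived
        and dataset_document.indexing_status == 'completed'
    )
```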
...@@ -4,8 +4,7 @@ import time ...@@ -4,8 +4,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
AppDatasetJoin AppDatasetJoin
...@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, ...@@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct=index_struct index_struct=index_struct
) )
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
index_doc_ids = [document.id for document in documents]
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index vector_index = IndexBuilder.get_index(dataset, 'high_quality')
if dataset.indexing_technique == "high_quality": kw_index = IndexBuilder.get_index(dataset, 'economy')
for index_doc_id in index_doc_ids:
try:
vector_index.del_doc(index_doc_id)
except Exception:
logging.exception("Delete doc index failed when dataset deleted.")
continue
# delete from keyword index # delete from vector index
if index_node_ids: if vector_index:
try: try:
keyword_table_index.del_nodes(index_node_ids) vector_index.delete()
except Exception: except Exception:
logging.exception("Delete nodes index failed when dataset deleted.") logging.exception("Delete doc index failed when dataset deleted.")
# delete from keyword index
try:
kw_index.delete()
except Exception:
logging.exception("Delete nodes index failed when dataset deleted.")
for document in documents: for document in documents:
db.session.delete(document) db.session.delete(document)
...@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, ...@@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
......
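With the new index abstraction, dataset cleanup no longer iterates over per-document ids; each index is dropped wholesale, best-effort. A minimal sketch, assuming the `delete()` semantics shown above; `drop_dataset_indexes` is a hypothetical helper:

```python
import logging

from core.index.index import IndexBuilder


def drop_dataset_indexes(dataset):
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    # Drop the whole vector collection; failures are logged so the rest of
    # the cleanup (segments, rules, queries) can still proceed.
    if vector_index:
        try:
            vector_index.delete()
        except Exception:
            logging.exception("Delete doc index failed when dataset deleted.")

    # Drop the keyword table index in the same best-effort fashion.
    try:
        kw_index.delete()
    except Exception:
        logging.exception("Delete nodes index failed when dataset deleted.")
```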
...@@ -4,8 +4,7 @@ import time ...@@ -4,8 +4,7 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset from models.dataset import DocumentSegment, Dataset
...@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str): ...@@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str):
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset) vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = KeywordTableIndex(dataset=dataset) kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
vector_index.del_nodes(index_node_ids) if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index # delete from keyword index
if index_node_ids: if index_node_ids:
keyword_table_index.del_nodes(index_node_ids) kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
db.session.commit() db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
......
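Document-level cleanup now works at two granularities: the vector store can drop everything attached to a document id in one call, while the keyword index is pruned by node ids. A minimal sketch, assuming the methods shown above; `clean_document_indexes` is an illustrative helper:

```python
from core.index.index import IndexBuilder


def clean_document_indexes(dataset, document_id: str, index_node_ids: list):
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    # One call removes every vector belonging to the document.
    if vector_index:
        vector_index.delete_by_document_id(document_id)

    # The keyword index is trimmed segment by segment via node ids.
    if index_node_ids:
        kw_index.delete_by_ids(index_node_ids)
```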
...@@ -5,8 +5,7 @@ from typing import List ...@@ -5,8 +5,7 @@ from typing import List
import click import click
from celery import shared_task from celery import shared_task
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, Dataset, Document from models.dataset import DocumentSegment, Dataset, Document
...@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str): ...@@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str):
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset) vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = KeywordTableIndex(dataset=dataset) kw_index = IndexBuilder.get_index(dataset, 'economy')
for document_id in document_ids: for document_id in document_ids:
document = db.session.query(Document).filter( document = db.session.query(Document).filter(
Document.id == document_id Document.id == document_id
).first() ).first()
db.session.delete(document) db.session.delete(document)
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
vector_index.del_nodes(index_node_ids) if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index # delete from keyword index
if index_node_ids: if index_node_ids:
keyword_table_index.del_nodes(index_node_ids) kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
......
...@@ -3,10 +3,12 @@ import time ...@@ -3,10 +3,12 @@ import time
import click import click
from celery import shared_task from celery import shared_task
from llama_index.data_structs.node_v2 import DocumentRelationship, Node from langchain.schema import Document
from core.index.vector_index import VectorIndex
from core.index.index import IndexBuilder
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, Document, Dataset from models.dataset import DocumentSegment, Dataset
from models.dataset import Document as DatasetDocument
@shared_task @shared_task
...@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): ...@@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
dataset = Dataset.query.filter_by( dataset = Dataset.query.filter_by(
id=dataset_id id=dataset_id
).first() ).first()
if not dataset: if not dataset:
raise Exception('Dataset not found') raise Exception('Dataset not found')
documents = Document.query.filter_by(dataset_id=dataset_id).all()
if documents: if action == "remove":
vector_index = VectorIndex(dataset=dataset) index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
for document in documents: index.delete()
# delete from vector index elif action == "add":
if action == "remove": dataset_documents = db.session.query(DatasetDocument).filter(
vector_index.del_doc(document.id) DatasetDocument.dataset_id == dataset_id,
elif action == "add": DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
if dataset_documents:
# save vector index
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
for dataset_document in dataset_documents:
# delete from vector index
segments = db.session.query(DocumentSegment).filter( segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True DocumentSegment.enabled == True
) .order_by(DocumentSegment.position.asc()).all() ) .order_by(DocumentSegment.position.asc()).all()
nodes = [] documents = []
previous_node = None
for segment in segments: for segment in segments:
relationships = { document = Document(
DocumentRelationship.SOURCE: document.id page_content=segment.content,
} metadata={
"doc_id": segment.index_node_id,
if previous_node: "doc_hash": segment.index_node_hash,
relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id "document_id": segment.document_id,
"dataset_id": segment.dataset_id,
previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id }
node = Node(
doc_id=segment.index_node_id,
doc_hash=segment.index_node_hash,
text=segment.content,
extra_info=None,
node_info=None,
relationships=relationships
) )
previous_node = node documents.append(document)
nodes.append(node)
# save vector index # save vector index
vector_index.add_nodes( index.add_texts(documents)
nodes=nodes,
duplicate_check=True
)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
......
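The dataset-level task above toggles the vector index as a whole, forcing an index handle with `ignore_high_quality_check=True` while the dataset's indexing technique is being switched. A condensed sketch of that control flow, assuming the interfaces in the hunk; `toggle_dataset_vector_index` is a hypothetical name:

```python
from core.index.index import IndexBuilder


def toggle_dataset_vector_index(dataset, action: str, documents):
    # ignore_high_quality_check returns an index handle even while the
    # dataset is mid-switch between indexing techniques.
    index = IndexBuilder.get_index(
        dataset, 'high_quality', ignore_high_quality_check=True
    )

    if action == "remove":
        # Drop the whole vector collection for this dataset.
        index.delete()
    elif action == "add":
        # Re-add the prepared langchain Documents (the hunk above builds them
        # from completed, enabled, non-archived documents).
        index.add_texts(documents)
```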
...@@ -6,11 +6,9 @@ import click ...@@ -6,11 +6,9 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.data_source.notion import NotionPageReader from core.data_loader.loader.notion import NotionLoader
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from core.indexing_runner import IndexingRunner, DocumentIsPausedException from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Document, Dataset, DocumentSegment from models.dataset import Document, Dataset, DocumentSegment
from models.source import DataSourceBinding from models.source import DataSourceBinding
...@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
raise ValueError("no notion page found") raise ValueError("no notion page found")
workspace_id = data_source_info['notion_workspace_id'] workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id'] page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
page_edited_time = data_source_info['last_edited_time'] page_edited_time = data_source_info['last_edited_time']
data_source_binding = DataSourceBinding.query.filter( data_source_binding = DataSourceBinding.query.filter(
db.and_( db.and_(
...@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
).first() ).first()
if not data_source_binding: if not data_source_binding:
raise ValueError('Data source binding not found.') raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
last_edited_time = reader.get_page_last_edited_time(page_id) loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
)
last_edited_time = loader.get_notion_last_edited_time()
# check the page is updated # check the page is updated
if last_edited_time != page_edited_time: if last_edited_time != page_edited_time:
document.indexing_status = 'parsing' document.indexing_status = 'parsing'
...@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
if not dataset: if not dataset:
raise Exception('Dataset not found') raise Exception('Dataset not found')
vector_index = VectorIndex(dataset=dataset) vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = KeywordTableIndex(dataset=dataset) kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
vector_index.del_nodes(index_node_ids) if vector_index:
vector_index.delete_by_document_id(document_id)
# delete from keyword index # delete from keyword index
if index_node_ids: if index_node_ids:
keyword_table_index.del_nodes(index_node_ids) kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
...@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ...@@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
except Exception: except Exception:
logging.exception("Cleaned document when document update data source or process rule failed") logging.exception("Cleaned document when document update data source or process rule failed")
try: try:
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
indexing_runner.run([document]) indexing_runner.run([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException: except DocumentIsPausedException as ex:
logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) logging.info(click.style(str(ex), fg='yellow'))
except ProviderTokenNotInitError as e: except Exception:
document.indexing_status = 'error' pass
document.error = str(e.description)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume update document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
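The sync task now asks the new `NotionLoader` for a page's last-edited timestamp instead of the old `NotionPageReader`. A minimal sketch of that change check, assuming the constructor arguments shown above; `notion_page_changed` is an illustrative helper:

```python
from core.data_loader.loader.notion import NotionLoader


def notion_page_changed(data_source_binding, workspace_id: str, page_id: str,
                        page_type: str, page_edited_time: str) -> bool:
    # The loader is scoped to a single Notion object (page or database) and
    # exposes its last-edited time for cheap change detection.
    loader = NotionLoader(
        notion_access_token=data_source_binding.access_token,
        notion_workspace_id=workspace_id,
        notion_obj_id=page_id,
        notion_page_type=page_type
    )
    return loader.get_notion_last_edited_time() != page_edited_time
```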
...@@ -7,7 +7,6 @@ from celery import shared_task ...@@ -7,7 +7,6 @@ from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.indexing_runner import IndexingRunner, DocumentIsPausedException from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Document from models.dataset import Document
...@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list): ...@@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list):
Usage: document_indexing_task.delay(dataset_id, document_id) Usage: document_indexing_task.delay(dataset_id, document_id)
""" """
documents = [] documents = []
start_at = time.perf_counter()
for document_id in document_ids: for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
start_at = time.perf_counter()
document = db.session.query(Document).filter( document = db.session.query(Document).filter(
Document.id == document_id, Document.id == document_id,
...@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): ...@@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
indexing_runner.run(documents) indexing_runner.run(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
except DocumentIsPausedException: except DocumentIsPausedException as ex:
logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) logging.info(click.style(str(ex), fg='yellow'))
except ProviderTokenNotInitError as e: except Exception:
document.indexing_status = 'error' pass
document.error = str(e.description)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
...@@ -6,10 +6,8 @@ import click ...@@ -6,10 +6,8 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from core.indexing_runner import IndexingRunner, DocumentIsPausedException from core.indexing_runner import IndexingRunner, DocumentIsPausedException
from core.llm.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Document, Dataset, DocumentSegment from models.dataset import Document, Dataset, DocumentSegment
...@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str): ...@@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
if not dataset: if not dataset:
raise Exception('Dataset not found') raise Exception('Dataset not found')
vector_index = VectorIndex(dataset=dataset) vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = KeywordTableIndex(dataset=dataset) kw_index = IndexBuilder.get_index(dataset, 'economy')
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
vector_index.del_nodes(index_node_ids) if vector_index:
vector_index.delete_by_ids(index_node_ids)
# delete from keyword index # delete from keyword index
if index_node_ids: if index_node_ids:
keyword_table_index.del_nodes(index_node_ids) kw_index.delete_by_ids(index_node_ids)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)
...@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): ...@@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
except Exception: except Exception:
logging.exception("Cleaned document when document update data source or process rule failed") logging.exception("Cleaned document when document update data source or process rule failed")
try: try:
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
indexing_runner.run([document]) indexing_runner.run([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException: except DocumentIsPausedException as ex:
logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) logging.info(click.style(str(ex), fg='yellow'))
except ProviderTokenNotInitError as e: except Exception:
document.indexing_status = 'error' pass
document.error = str(e.description)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume update document failed")
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
import datetime
import logging import logging
import time import time
...@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): ...@@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
indexing_runner.run_in_indexing_status(document) indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green'))
except DocumentIsPausedException: except DocumentIsPausedException as ex:
logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) logging.info(click.style(str(ex), fg='yellow'))
except Exception as e: except Exception:
logging.exception("consume document failed") pass
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
...@@ -5,8 +5,7 @@ import click ...@@ -5,8 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment, Document from models.dataset import DocumentSegment, Document
...@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str): ...@@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str):
if not dataset: if not dataset:
raise Exception('Document has no dataset') raise Exception('Document has no dataset')
vector_index = VectorIndex(dataset=dataset) vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = KeywordTableIndex(dataset=dataset) kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index # delete from vector index
vector_index.del_doc(document.id) vector_index.delete_by_document_id(document.id)
# delete from keyword index # delete from keyword index
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids: if index_node_ids:
keyword_table_index.del_nodes(index_node_ids) kw_index.delete_by_ids(index_node_ids)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
......
...@@ -5,8 +5,7 @@ import click ...@@ -5,8 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.index.keyword_table_index import KeywordTableIndex from core.index.index import IndexBuilder
from core.index.vector_index import VectorIndex
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment from models.dataset import DocumentSegment
...@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str): ...@@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str):
dataset = segment.dataset dataset = segment.dataset
if not dataset: if not dataset:
raise Exception('Segment has no dataset') logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
return
vector_index = VectorIndex(dataset=dataset) dataset_document = segment.document
keyword_table_index = KeywordTableIndex(dataset=dataset)
if not dataset_document:
logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan'))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan'))
return
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index # delete from vector index
if dataset.indexing_technique == "high_quality": if vector_index:
vector_index.del_nodes([segment.index_node_id]) vector_index.delete_by_ids([segment.index_node_id])
# delete from keyword index # delete from keyword index
keyword_table_index.del_nodes([segment.index_node_id]) kw_index.delete_by_ids([segment.index_node_id])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
......