Unverified Commit 3241e401 authored by John Wang, committed by GitHub

feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
parent 1dee5de9
...@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
import flask_login
from flask_cors import CORS
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage
from extensions.ext_database import db
from extensions.ext_login import login_manager
...@@ -79,7 +79,6 @@ def initialize_extensions(app):
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
ext_vector_store.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_session.init_app(app)
......
import datetime
import logging
import random
import string
import click
from flask import current_app
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset
from models.model import Account
import secrets
import base64
...@@ -159,8 +163,39 @@ def generate_upper_string():
return result
@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
def recreate_all_dataset_indexes():
click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
recreate_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
try:
click.echo('Recreating dataset index: {}'.format(dataset.id))
index = IndexBuilder.get_index(dataset, 'high_quality')
if index and index._is_origin():
index.recreate_dataset(dataset)
recreate_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)
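Since the new command is registered on the Flask CLI alongside the existing maintenance commands, it would typically be invoked from the api service the same way as the others; a usage sketch (not part of the diff):

    flask recreate-all-dataset-indexes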
...@@ -187,11 +187,13 @@ class Config:
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
class CloudEditionConfig(Config):
......
...@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
...@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
).first()
if not data_source_binding:
raise NotFound('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
if page_type == 'page':
page_content = reader.read_page(page_id)
elif page_type == 'database':
page_content = reader.query_database_data(page_id)
else:
page_content = ""
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
)
text_docs = loader.load()
return {
'content': page_content
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@setup_required
......
...@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
...@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
if extension not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, filepath)
if extension == 'pdf':
parser = PDFParser({'upload_file': upload_file})
text = parser.parse_file(Path(filepath))
elif extension in ['html', 'htm']:
# Use BeautifulSoup to extract text
parser = HTMLParser()
text = parser.parse_file(Path(filepath))
elif extension == 'xlsx':
parser = XLSXParser()
text = parser.parse_file(filepath)
else:
# ['txt', 'markdown', 'md']
with open(filepath, "rb") as fp:
data = fp.read()
encoding = chardet.detect(data)['encoding']
if encoding:
text = data.decode(encoding=encoding).strip() if data else ''
else:
text = data.decode(encoding='utf-8').strip() if data else ''
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text}
......
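The preview endpoint above now delegates all per-extension parsing to the new FileExtractor. A minimal sketch of the consolidated call, assuming upload_file is the UploadFile row fetched earlier and PREVIEW_WORDS_LIMIT is the existing constant:

    # return_text=True yields a plain string; return_text=False yields a list of langchain Documents
    text = FileExtractor.load(upload_file, return_text=True)
    text = text[0:PREVIEW_WORDS_LIMIT] if text else ''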
...@@ -32,8 +32,13 @@ class VersionApi(Resource):
'current_version': args.get('current_version')
})
except Exception as error:
logging.exception("Check update error.")
raise InternalServerError()
logging.warning("Check update version error: {}.".format(str(error)))
return {
'version': args.get('current_version'),
'release_date': '',
'release_notes': '',
'can_auto_update': False
}
content = json.loads(response.content)
return {
......
...@@ -3,19 +3,11 @@ from typing import Optional
import langchain
from flask import Flask
from jieba.analyse import default_tfidf
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from llama_index import IndexStructType, QueryMode
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.index.keyword_table.stopwords import STOPWORDS
from core.prompt.prompt_template import OneLineFormatter
from core.vector_store.vector_store import VectorStore
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
class HostedOpenAICredential(BaseModel):
...@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
formatter = OneLineFormatter()
DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
default_tfidf.stop_words = STOPWORDS
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True langchain.verbose = True
set_handler(DifyStdOutCallbackHandler())
if app.config.get("OPENAI_API_KEY"): if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
...@@ -2,7 +2,7 @@ from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
...@@ -16,23 +16,20 @@ class AgentBuilder:
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
tool_callback_manager = CallbackManager([
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
])
for tool in tools:
tool.callback_manager = tool_callback_manager
tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
]
prompt = cls.build_agent_prompt_template(
tools=tools,
...@@ -54,7 +51,7 @@ class AgentBuilder:
tools=tools,
agent=agent,
memory=memory,
callback_manager=agent_callback_manager,
callbacks=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
......
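The changes above follow the langchain API migration this commit targets: handlers are no longer wrapped in a CallbackManager and attached via callback_manager, but passed directly as a callbacks list. A minimal before/after sketch using the handler names from this file (illustrative only, arguments abbreviated):

    # before: wrap handlers in a CallbackManager
    llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
    llm = LLMBuilder.to_llm(tenant_id=tenant_id, model_name=model_name, callback_manager=llm_callback_manager)

    # after: pass handlers directly as a list
    llm = LLMBuilder.to_llm(tenant_id=tenant_id, model_name=model_name,
                            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])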
...@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
...@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
...@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._agent_loops = []
self._current_loop = None
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
...@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._agent_loops = []
self._current_loop = None
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
# Final Answer
......
...@@ -3,7 +3,6 @@ import logging
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
...@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
...@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
) -> None:
"""Do nothing."""
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass
from llama_index import Response
from typing import List
from langchain.schema import Document
from extensions.ext_database import db
from models.dataset import DocumentSegment
class IndexToolCallbackHandler:
class DatasetIndexToolCallbackHandler:
def __init__(self) -> None:
self._response = None
@property
def response(self) -> Response:
return self._response
def on_tool_end(self, response: Response) -> None:
"""Handle tool end."""
self._response = response
class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
"""Callback handler for dataset tool.""" """Callback handler for dataset tool."""
def __init__(self, dataset_id: str) -> None: def __init__(self, dataset_id: str) -> None:
super().__init__()
self.dataset_id = dataset_id self.dataset_id = dataset_id
def on_tool_end(self, response: Response) -> None: def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end.""" """Handle tool end."""
for node in response.source_nodes: for document in documents:
index_node_id = node.node.doc_id doc_id = document.metadata['doc_id']
# add hit count to document segment # add hit count to document segment
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id, DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == index_node_id DocumentSegment.index_node_id == doc_id
).update( ).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False synchronize_session=False
......
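A minimal sketch of how the reworked handler is driven, assuming dataset is an existing Dataset row and that retrieved langchain Documents carry the segment's index node id under metadata['doc_id'] as the code above expects:

    from langchain.schema import Document

    handler = DatasetIndexToolCallbackHandler(dataset_id=str(dataset.id))
    docs = [Document(page_content="some segment text", metadata={"doc_id": "node-123"})]
    handler.on_tool_end(docs)  # increments DocumentSegment.hit_count for the matching index_node_id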
...@@ -3,7 +3,7 @@ import time
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
...@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
conversation_message_task: ConversationMessageTask):
...@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
real_prompts = []
for message in messages[0]:
if message.type == 'human':
role = 'user'
elif message.type == 'ai':
role = 'assistant'
else:
role = 'system'
real_prompts.append({
"role": role,
"text": message.content
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
if 'Chat' in serialized['name']:
real_prompts = []
messages = []
for prompt in prompts:
role, content = prompt.split(': ', maxsplit=1)
if role == 'human':
role = 'user'
message = HumanMessage(content=content)
elif role == 'ai':
role = 'assistant'
message = AIMessage(content=content)
else:
message = SystemMessage(content=content)
real_prompt = {
"role": role,
"text": content
}
real_prompts.append(real_prompt)
messages.append(message)
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
else:
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
...@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
pass
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
...@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
class MainChainGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
...@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
self._current_chain_result = ChainResult(
type=serialized['name'],
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
chain_type = serialized['id'][-1]
if chain_type:
self._current_chain_result = ChainResult(
type=chain_type,
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
...@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
self.clear_chain_results()
\ No newline at end of file
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.error(error)
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass
import os
import sys
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
class DifyStdOutCallbackHandler(BaseCallbackHandler):
...@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Initialize callback handler."""
self.color = color
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
for sub_messages in messages:
for sub_message in sub_messages:
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
"""Print out the prompts.""" """Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue') print_text("\n[on_llm_start]\n", color='blue')
print_text(prompts[0] + "\n", color='blue')
if 'Chat' in serialized['name']:
for prompt in prompts:
print_text(prompt + "\n", color='blue')
else:
print_text(prompts[0] + "\n", color='blue')
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
...@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
...@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent end."""
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
......
from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
...@@ -14,7 +12,7 @@ class ChainBuilder:
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
...@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)
......
"""Base classes for LLM-powered router chains.""" """Base classes for LLM-powered router chains."""
from __future__ import annotations from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from pydantic import root_validator from pydantic import root_validator
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown from libs.json_in_md_parser import parse_and_check_json_markdown
...@@ -51,8 +52,9 @@ class LLMRouterChain(Chain): ...@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
raise ValueError raise ValueError
def _call( def _call(
self, self,
inputs: Dict[str, Any] inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
output = cast( output = cast(
Dict[str, Any], Dict[str, Any],
......
from typing import Optional, List
from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
...@@ -18,6 +16,7 @@ from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
...@@ -30,6 +29,7 @@ class MainChainBuilder:
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
...@@ -42,9 +42,8 @@ class MainChainBuilder:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
...@@ -57,7 +56,9 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
...@@ -93,7 +94,8 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
......
import math
from typing import Mapping, List, Dict, Any, Optional
from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain import PromptTemplate
from langchain.callbacks import CallbackManager
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
...@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.dataset_index_tool import DatasetTool
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Dataset
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
...@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
dataset_tools: Mapping[str, DatasetTool]
"""Map of name to candidate chains that inputs can be routed to."""
class Config:
...@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
dataset_tools = {}
for dataset in datasets:
dataset_tool = DatasetToolBuilder.build_dataset_tool(
# fulfill description when it is empty
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
continue
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=k,
dataset=dataset,
response_mode='no_synthesizer',  # "compact"
callback_handler=DatasetToolCallbackHandler(conversation_message_task)
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
if dataset_tool:
dataset_tools[dataset.id] = dataset_tool
dataset_tools[str(dataset.id)] = dataset_tool
return cls(
router_chain=router_chain,
...@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
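A worked example of the sizing logic above, with assumed numbers (segment_max_tokens = 500, so DEFAULT_K * segment_max_tokens = 1000):

    # rest_tokens = 800:   800 < 1000, so k = 800 // 500 = 1
    # rest_tokens = 2500:  context_limit = floor(2500 * 0.3) = 750 <= 1000, so k = DEFAULT_K = 2
    # rest_tokens = 10000: context_limit = floor(10000 * 0.3) = 3000 > 1000, so k = 3000 // 500 = 6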
def _call(
self,
inputs: Dict[str, Any]
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if len(self.dataset_tools) == 0:
return {"text": ''}
......
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
...@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
return self.canned_response
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
...@@ -30,12 +31,20 @@ class ToolChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output.""" """Run the logic of this chain and return the output."""
input = inputs[self.input_key] input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose) output = await self.tool.arun(input, self.verbose)
......
import logging
from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
...@@ -34,8 +35,6 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
memory = None
if conversation:
# get memory of conversation (read-only)
...@@ -48,6 +47,14 @@ class Completion:
inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
...@@ -64,6 +71,7 @@ class Completion:
main_chain = MainChainBuilder.to_langchain_components(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
)
...@@ -115,7 +123,7 @@ class Completion:
memory=memory
)
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
...@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
return messages, ['\nHuman:']
@classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager:
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
return [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
...@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
return memory
@classmethod
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
...@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
raise LLMBadRequestError("Query is too long")
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=None,
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
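The remaining budget is straightforward arithmetic: context window minus tokens reserved for the completion minus tokens already consumed by the prompt. A hedged example with made-up numbers (the real limits come from llm_constant.max_context_token_length):

# hypothetical figures for illustration only
model_limited_tokens = 4096  # context window of the selected model
max_tokens = 512             # tokens reserved for the completion
prompt_tokens = 1200         # tokens used by pre_prompt + inputs + query
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens  # 2384 left for context and memory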
@classmethod @classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
...@@ -360,7 +385,7 @@ And answer according to the language of the user's question. ...@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
streaming=streaming streaming=streaming
) )
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=llm, final_llm=llm,
......
...@@ -293,12 +293,12 @@ class PubHandler: ...@@ -293,12 +293,12 @@ class PubHandler:
if not user: if not user:
raise ValueError("user is required") raise ValueError("user is required")
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result:{}-{}".format(user_str, task_id) return "generate_result:{}-{}".format(user_str, task_id)
@classmethod @classmethod
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id) return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str): def pub_text(self, text: str):
...@@ -306,10 +306,10 @@ class PubHandler: ...@@ -306,10 +306,10 @@ class PubHandler:
'event': 'message', 'event': 'message',
'data': { 'data': {
'task_id': self._task_id, 'task_id': self._task_id,
'message_id': self._message.id, 'message_id': str(self._message.id),
'text': text, 'text': text,
'mode': self._conversation.mode, 'mode': self._conversation.mode,
'conversation_id': self._conversation.id 'conversation_id': str(self._conversation.id)
} }
} }
......
import tempfile
from pathlib import Path
from typing import List, Union
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
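A minimal usage sketch, assuming an UploadFile row already exists in the database; file_id and the FileExtractor module path are illustrative assumptions:

from extensions.ext_database import db
from models.model import UploadFile
from core.data_loader.file_extractor import FileExtractor  # module path assumed

upload_file = db.session.query(UploadFile).filter_by(id=file_id).first()  # file_id is hypothetical
documents = FileExtractor.load(upload_file)                               # List[Document]
text = FileExtractor.load(upload_file, return_text=True)                  # page contents joined with '\n'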
import csv
import logging
from typing import Optional, Dict, List
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from models.dataset import Document
logger = logging.getLogger(__name__)
class CSVLoader(LCCSVLoader):
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
logger.debug("Trying encoding: %s", encoding.encoding)
try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def _read_from_file(self, csvfile):
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
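A usage sketch of the loader on its own, assuming a local CSV path; the path and source_column name are hypothetical:

from core.data_loader.loader.csv import CSVLoader

loader = CSVLoader('/tmp/products.csv', source_column='url', autodetect_encoding=True)
docs = loader.load()  # one Document per row; content is "header: value" lines, metadata carries source and row index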
import json
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__)
class ExcelLoader(BaseLoader):
"""Load xlxs files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return [Document(page_content='\n\n'.join(data))]
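A short usage sketch with a hypothetical path:

from core.data_loader.loader.excel import ExcelLoader

docs = ExcelLoader('/tmp/catalog.xlsx').load()  # one Document; each non-empty row serialized as a JSON object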
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
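A short usage sketch with a hypothetical path:

from core.data_loader.loader.html import HTMLLoader

docs = HTMLLoader('/tmp/page.html').load()  # one Document containing only the page's visible text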
"""Markdown parser. import logging
import re
from typing import Optional, List, Tuple, cast
Contains parser for md files. from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class MarkdownLoader(BaseLoader):
"""Load md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from llama_index.readers.file.base_parser import BaseParser Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
class MarkdownParser(BaseParser): remove_images: Whether to remove images from the text.
"""Markdown parser.
Extract text from markdown files. encoding: File encoding to use. If `None`, the file will be loaded
Returns dictionary with keys as headers and values as the text between headers. with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
""" """
def __init__( def __init__(
self, self,
*args: Any, file_path: str,
remove_hyperlinks: bool = True, remove_hyperlinks: bool = True,
remove_images: bool = True, remove_images: bool = True,
**kwargs: Any, encoding: Optional[str] = None,
) -> None: autodetect_encoding: bool = True,
"""Init params.""" ):
super().__init__(*args, **kwargs) """Initialize with file path."""
self._file_path = file_path
self._remove_hyperlinks = remove_hyperlinks self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images self._remove_images = remove_images
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
value = value.strip()
if header is None:
documents.append(Document(page_content=value))
else:
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
return documents
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary. """Convert a markdown file to a dictionary.
...@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser): ...@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
content = re.sub(pattern, r"\1", content) content = re.sub(pattern, r"\1", content)
return content return content
def _init_parser(self) -> Dict: def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples.""" """Parse file into tuples."""
with open(filepath, "r", encoding="utf-8") as f: content = ""
content = f.read() try:
with open(filepath, "r", encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
logger.debug("Trying encoding: %s", encoding.encoding)
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {filepath}") from e
except Exception as e:
raise RuntimeError(f"Error loading {filepath}") from e
if self._remove_hyperlinks: if self._remove_hyperlinks:
content = self.remove_hyperlinks(content) content = self.remove_hyperlinks(content)
if self._remove_images: if self._remove_images:
content = self.remove_images(content) content = self.remove_images(content)
markdown_tups = self.markdown_to_tups(content)
return markdown_tups
def parse_file( return self.markdown_to_tups(content)
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results
import logging
from typing import List, Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents
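A usage sketch; when an UploadFile with a hash is passed, the extracted plaintext is cached in storage under a '.0625.plaintext' key so later loads skip PyPDFium2. Both values below are hypothetical:

from core.data_loader.loader.pdf import PdfLoader

docs = PdfLoader('/tmp/report.pdf', upload_file=upload_file).load()  # upload_file is optional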
from typing import Any, Dict, Optional, Sequence from typing import Any, Dict, Optional, Sequence
import tiktoken from langchain.schema import Document
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from sqlalchemy import func from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator from core.llm.token_calculator import TokenCalculator
...@@ -12,7 +8,7 @@ from extensions.ext_database import db ...@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore(BaseDocumentStore): class DatesetDocumentStore:
def __init__( def __init__(
self, self,
dataset: Dataset, dataset: Dataset,
...@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
return self._embedding_model_name return self._embedding_model_name
@property @property
def docs(self) -> Dict[str, BaseDocument]: def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter( document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id DocumentSegment.dataset_id == self._dataset.id
).all() ).all()
...@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
output = {} output = {}
for document_segment in document_segments: for document_segment in document_segments:
doc_id = document_segment.index_node_id doc_id = document_segment.index_node_id
result = self.segment_to_dict(document_segment) output[doc_id] = Document(
output[doc_id] = json_to_doc(result) page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
return output return output
def add_documents( def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True self, docs: Sequence[Document], allow_update: bool = True
) -> None: ) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter( max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id DocumentSegment.document == self._document_id
...@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
max_position = 0 max_position = 0
for doc in docs: for doc in docs:
if doc.is_doc_id_none: if not isinstance(doc, Document):
raise ValueError("doc_id not set") raise ValueError("doc must be a Document")
if not isinstance(doc, Node): segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
raise ValueError("doc must be a Node")
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
# NOTE: doc could already exist in the store, but we overwrite it # NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document: if not allow_update and segment_document:
raise ValueError( raise ValueError(
f"doc_id {doc.get_doc_id()} already exists. " f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite." "Set allow_update to True to overwrite."
) )
# calc embedding use tokens # calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text()) tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
if not segment_document: if not segment_document:
max_position += 1 max_position += 1
...@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
tenant_id=self._dataset.tenant_id, tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id, dataset_id=self._dataset.id,
document_id=self._document_id, document_id=self._document_id,
index_node_id=doc.get_doc_id(), index_node_id=doc.metadata['doc_id'],
index_node_hash=doc.get_doc_hash(), index_node_hash=doc.metadata['doc_hash'],
position=max_position, position=max_position,
content=doc.get_text(), content=doc.page_content,
word_count=len(doc.get_text()), word_count=len(doc.page_content),
tokens=tokens, tokens=tokens,
created_by=self._user_id, created_by=self._user_id,
) )
db.session.add(segment_document) db.session.add(segment_document)
else: else:
segment_document.content = doc.get_text() segment_document.content = doc.page_content
segment_document.index_node_hash = doc.get_doc_hash() segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.get_text()) segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens segment_document.tokens = tokens
db.session.commit() db.session.commit()
...@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
def get_document( def get_document(
self, doc_id: str, raise_error: bool = True self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]: ) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id) document_segment = self.get_document_segment(doc_id)
if document_segment is None: if document_segment is None:
...@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
else: else:
return None return None
result = self.segment_to_dict(document_segment) return Document(
return json_to_doc(result) page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None: def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
document_segment = self.get_document_segment(doc_id) document_segment = self.get_document_segment(doc_id)
...@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
return document_segment.index_node_hash return document_segment.index_node_hash
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
def get_document_segment(self, doc_id: str) -> DocumentSegment: def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter( document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.dataset_id == self._dataset.id,
...@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore): ...@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
).first() ).first()
return document_segment return document_segment
def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
return {
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"text": segment.content,
"__type__": Node.get_type()
}
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {}
@property
def docs(self) -> Dict[str, BaseDocument]:
return {}
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
pass
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
return False
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
return None
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
pass
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
pass
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
return None
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
import logging
from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
# use doc embedding cache or store if not exists
text_embeddings = []
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except Exception:
logging.exception('Failed to add embedding to db')
continue
i += 1
text_embeddings.extend(embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
except Exception:
logging.exception('Failed to add embedding to db')
return embedding_results
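A minimal sketch of how the cache wrapper is composed with a langchain embeddings object; the API key is a placeholder:

from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding

embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key='sk-...'))  # placeholder key
vectors = embeddings.embed_documents(["hello world", "hello world"])    # repeated text is served from the Embedding table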
from typing import Optional, Any, List
import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
_TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
By default, this is a wrapper around _get_text_embedding.
Can be overridden for batch queries.
"""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(self._get_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(await self._aget_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any
from langchain.schema import Document, BaseRetriever
from models.dataset import Dataset
class BaseIndex(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
@abstractmethod
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
raise NotImplementedError
@abstractmethod
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
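A hedged sketch of how an index built on this interface is typically consumed; the concrete index object is assumed to exist:

retriever = index.get_retriever(search_kwargs={'k': 2})   # index is any concrete BaseIndex subclass
docs = retriever.get_relevant_documents("what is dify?")  # standard langchain BaseRetriever call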
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
\ No newline at end of file
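A usage sketch, assuming a models.dataset.Dataset row is at hand and the module path is as imported elsewhere in this change:

from core.index.index import IndexBuilder  # module path assumed

index = IndexBuilder.get_index(dataset, 'high_quality')
if index:
    index.add_texts(documents)  # Documents must carry doc_id / doc_hash metadata, as the docstore above sets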
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder
class IndexBuilder:
@classmethod
def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
# set number of output tokens
num_output = 512
# only for verbose
callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='text-davinci-003',
temperature=0,
max_tokens=num_output,
callback_manager=callback_manager,
)
llm_predictor = LLMPredictor(llm=llm)
# These parameters here will affect the logic of segmenting the final synthesized response.
# The number of refinement iterations in the synthesis process depends
# on whether the length of the segmented output exceeds the max_input_size.
prompt_helper = PromptHelper(
max_input_size=3500,
num_output=num_output,
max_chunk_overlap=20
)
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002'
)
return ServiceContext.from_defaults(
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
embed_model=OpenAIEmbedding(**model_credentials),
)
@classmethod
def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='fake'
)
return ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
import re
from typing import (
Any,
Dict,
List,
Set,
Optional
)
import jieba.analyse
from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
def jieba_extract_keywords(
text_chunk: str,
max_keywords: Optional[int] = None,
expand_with_subtokens: bool = True,
) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text_chunk,
topK=max_keywords,
)
if expand_with_subtokens:
return set(expand_tokens_with_subtokens(keywords))
else:
return set(keywords)
def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
"""GPT JIEBA Keyword Table Index.
This index uses a JIEBA keyword extractor to extract keywords from the text.
"""
def _extract_keywords(self, text: str) -> Set[str]:
"""Extract keywords from text."""
return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
@classmethod
def get_query_map(self) -> QueryMap:
"""Get query map."""
super_map = super().get_query_map()
super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
return super_map
def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
# get set of ids that correspond to node
node_idxs_to_delete = {doc_id}
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in self._index_struct.table.items():
if node_idxs_to_delete.intersection(node_idxs):
self._index_struct.table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not self._index_struct.table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del self._index_struct.table[keyword]
class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
"""GPT Keyword Table Index JIEBA Query.
Extracts keywords using JIEBA keyword extractor.
Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
.. code-block:: python
response = index.query("<query_str>", mode="jieba")
See BaseGPTKeywordTableQuery for arguments.
"""
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
def _get_keywords(self, query_str: str) -> List[str]:
"""Extract keywords."""
return list(
jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
)
import json
from typing import List, Optional
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
index_struct = KeywordTable()
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node in nodes:
keywords = index._extract_keywords(node.get_text())
self.update_segment_keywords(node.doc_id, list(keywords))
index._index_struct.add_node(list(keywords), node)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
def del_nodes(self, node_ids: List[str]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node_id in node_ids:
index.delete(node_id)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
@property
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
docstore = DatesetDocumentStore(
dataset=self._dataset,
user_id=self._dataset.created_by,
embedding_model_name="text-embedding-ada-002"
)
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return None
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
def get_keyword_table(self):
dataset_keyword_table = self._dataset.dataset_keyword_table
if dataset_keyword_table:
return dataset_keyword_table
return None
def update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
from core.index.keyword_table_index.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
\ No newline at end of file
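A short usage sketch; the sample text is illustrative:

from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler

handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords("Dify 是一个 LLM 应用开发平台", max_keywords_per_chunk=5)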
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict
from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra
from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class KeywordTableIndex(BaseIndex):
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
super().__init__(dataset)
self._config = config
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
return id in set.union(set(), *keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
k = search_kwargs.get('k') if search_kwargs.get('k') else 4
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
return documents
def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
}
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:
return dataset_keyword_table.keyword_table_dict['__data__']['table']
else:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
list(chunk_indices_count.keys()),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
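A minimal end-to-end sketch of the economy index; dataset, documents and the query are assumed, and documents must carry doc_id metadata:

from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig

index = KeywordTableIndex(dataset=dataset, config=KeywordTableConfig(max_keywords_per_chunk=10))
index.add_texts(documents)                                     # extracts jieba keywords per Document
results = index.search("向量 数据库", search_kwargs={'k': 4})  # top-k segments by keyword overlap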
from typing import (
Any,
Dict,
Optional, Sequence,
)
from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE
class EnhanceResponseSynthesizer(ResponseSynthesizer):
@classmethod
def from_args(
cls,
service_context: ServiceContext,
streaming: bool = False,
use_async: bool = False,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_mode: ResponseMode = ResponseMode.DEFAULT,
response_kwargs: Optional[Dict] = None,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
) -> "ResponseSynthesizer":
response_builder: Optional[BaseResponseBuilder] = None
if response_mode != ResponseMode.NO_TEXT:
if response_mode == 'no_synthesizer':
response_builder = NoSynthesizer(
service_context=service_context,
simple_template=simple_template,
streaming=streaming,
)
else:
response_builder = get_response_builder(
service_context,
text_qa_template,
refine_template,
simple_template,
response_mode,
use_async=use_async,
streaming=streaming,
)
return cls(response_builder, response_mode, response_kwargs, optimizer)
class NoSynthesizer(BaseResponseBuilder):
def __init__(
self,
service_context: ServiceContext,
simple_template: Optional[SimpleInputPrompt] = None,
streaming: bool = False,
) -> None:
super().__init__(service_context, streaming)
async def aget_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
\ No newline at end of file
from pathlib import Path
from typing import Dict
from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
with open(file, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
from pathlib import Path
from typing import Dict
from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader
from extensions.ext_storage import storage
from models.model import UploadFile
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if not current_app.config.get('PDF_PREVIEW', True):
return ''
plaintext_file_key = ''
plaintext_file_exists = False
if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
upload_file: UploadFile = self._parser_config['upload_file']
if upload_file.hash:
plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return text
except FileNotFoundError:
pass
text_list = []
with open(file, "rb") as fp:
# Create a PDF object
pdf = PdfReader(fp)
# Get the number of pages in the PDF document
num_pages = len(pdf.pages)
# Iterate over every page
for page in range(num_pages):
# Extract the text from the page
page_text = pdf.pages[page].extract_text()
text_list.append(page_text)
text = "\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return text
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app
class XLSXParser(BaseParser):
"""XLSX parser."""
def _init_parser(self) -> Dict:
"""Init parser"""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
data = []
keys = []
with open(file, "r") as fp:
wb = load_workbook(filename=file, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return '\n\n'.join(data)
import json
import logging
from typing import List, Optional
from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding
class VectorIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
if duplicate_check:
nodes = self._filter_duplicate_nodes(index, nodes)
embedding_queue_nodes = []
embedded_nodes = []
for node in nodes:
node_hash = node.doc_hash
# if node hash in cached embedding tables, use cached embedding
embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
if embedding:
node.embedding = embedding.get_embedding()
embedded_nodes.append(node)
else:
embedding_queue_nodes.append(node)
if embedding_queue_nodes:
embedding_results = index._get_node_embedding_results(
embedding_queue_nodes,
set(),
)
# pre embed nodes for cached embedding
for embedding_result in embedding_results:
node = embedding_result.node
node.embedding = embedding_result.embedding
try:
embedding = Embedding(hash=node.doc_hash)
embedding.set_embedding(node.embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
# avoid a bare except; still skip nodes whose embedding could not be persisted
except Exception:
logging.exception('Failed to add embedding to db')
continue
embedded_nodes.append(node)
self.index_insert_nodes(index, embedded_nodes)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
index.insert_nodes(nodes)
def del_nodes(self, node_ids: List[str]):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
for node_id in node_ids:
self.index_delete_node(index, node_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
index.delete_node(node_id)
def del_doc(self, doc_id: str):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
self.index_delete_doc(index, doc_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
index.delete(doc_id)
@property
def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
if not self._dataset.index_struct_dict:
return None
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
return vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
# iterate over a copy of the list: removing items from the list being iterated would skip nodes
for node in list(nodes):
node_id = node.doc_id
exists_duplicate_node = index.exists_by_node_id(node_id)
if exists_duplicate_node:
nodes.remove(node)
return nodes
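The add_nodes path above reuses embeddings cached in the Embedding table by node content hash, so re-indexing identical chunks avoids extra embedding calls. A rough sketch of that lookup, assuming an application/database context and a llama_index Node whose doc_hash is set:

# sketch: look up a cached embedding for a node by its content hash
cached = db.session.query(Embedding).filter_by(hash=node.doc_hash).first()
if cached:
    node.embedding = cached.get_embedding()  # reuse instead of re-calling the embedding model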
import json
import logging
from abc import abstractmethod
from typing import List, Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = 0.0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
import os
from typing import Optional, Any, List, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
# original class_prefix
return True
return False
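QdrantConfig maps either a local 'path:' endpoint or a remote URL onto qdrant_client keyword arguments. An illustrative sketch with hypothetical values:

# local, file-backed storage: relative paths are resolved against root_path
local = QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app/api')
print(local.to_qdrant_params())   # {'path': '/app/api/storage/qdrant'}

# remote server: URL plus optional API key
remote = QdrantConfig(endpoint='https://qdrant.example.com:6333', api_key='my-key', root_path='/app/api')
print(remote.to_qdrant_params())  # {'url': 'https://qdrant.example.com:6333', 'api_key': 'my-key'}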
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError(f"Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
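This wrapper picks the backend from the dataset's stored index_struct, falling back to the VECTOR_STORE config key, and delegates any other attribute to the concrete index. A hedged construction sketch, assuming a Flask app context, an existing Dataset row, a langchain Embeddings instance, and a list of langchain Documents:

from flask import current_app

vector_index = VectorIndex(
    dataset=dataset,              # an existing models.dataset.Dataset row (assumed)
    config=current_app.config,
    embeddings=embeddings,        # a langchain Embeddings instance (assumed)
)
# creates the index on first use, otherwise appends to the existing collection
vector_index.add_texts(documents, duplicate_check=True)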
from typing import Optional, cast
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
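get_index_name and _is_origin together handle datasets created before this upgrade: a stored class_prefix without the '_Node' suffix marks a legacy index, which is rebuilt via recreate_dataset rather than written to in place. An illustrative sketch of the naming scheme for a hypothetical dataset id:

# illustrative only: Weaviate class name derived from a dataset id
dataset_id = "2c90e1a4-93f5-4a2b-8d11-0a9e5f0c1b2d"
print("Vector_index_" + dataset_id.replace("-", "_") + "_Node")
# Vector_index_2c90e1a4_93f5_4a2b_8d11_0a9e5f0c1b2d_Node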
...@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
        """
        config = self.get_provider_api_key(model_id=model_id)
        config['openai_api_type'] = 'azure'
-       config['deployment_name'] = model_id.replace('.', '') if model_id else None
+       if model_id == 'text-embedding-ada-002':
+           config['deployment'] = model_id.replace('.', '') if model_id else None
+       else:
+           config['deployment_name'] = model_id.replace('.', '') if model_id else None
        return config

    def get_provider_name(self):
......
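For context on this hunk: in the langchain release targeted by this upgrade, the Azure embeddings wrapper appears to take its deployment via a `deployment` keyword while the LLM/chat wrappers take `deployment_name`, which would explain the special case for `text-embedding-ada-002`. A purely illustrative sketch of the two credential shapes (values hypothetical):

# assumption: only the deployment key name differs between the two credential dicts
embedding_credentials = {'openai_api_type': 'azure', 'deployment': 'text-embedding-ada-002'}
llm_credentials = {'openai_api_type': 'azure', 'deployment_name': 'gpt-35-turbo'}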
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
+from langchain.schema import get_buffer_string
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
......
from core.vector_store.vector_store import VectorStore
vector_store = VectorStore()
def init_app(app):
vector_store.init_app(app)
...@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

-   _current_tenant: db.Model = None

    @property
    def current_tenant(self):
        return self._current_tenant
......