Unverified Commit 3241e401 authored by John Wang's avatar John Wang Committed by GitHub

feat: upgrade langchain (#430)

Co-authored-by: 's avatarjyong <718720800@qq.com>
parent 1dee5de9
......@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
import flask_login
from flask_cors import CORS
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage
from extensions.ext_database import db
from extensions.ext_login import login_manager
......@@ -79,7 +79,6 @@ def initialize_extensions(app):
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
ext_vector_store.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_session.init_app(app)
......
import datetime
import logging
import random
import string
import click
from flask import current_app
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset
from models.model import Account
import secrets
import base64
......@@ -159,8 +163,39 @@ def generate_upper_string():
return result
@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
def recreate_all_dataset_indexes():
click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
recreate_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
try:
click.echo('Recreating dataset index: {}'.format(dataset.id))
index = IndexBuilder.get_index(dataset, 'high_quality')
if index and index._is_origin():
index.recreate_dataset(dataset)
recreate_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)
......@@ -187,11 +187,13 @@ class Config:
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
class CloudEditionConfig(Config):
......
......@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
......@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
).first()
if not data_source_binding:
raise NotFound('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
if page_type == 'page':
page_content = reader.read_page(page_id)
elif page_type == 'database':
page_content = reader.query_database_data(page_id)
else:
page_content = ""
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
)
text_docs = loader.load()
return {
'content': page_content
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@setup_required
......
......@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
......@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
if extension not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, filepath)
if extension == 'pdf':
parser = PDFParser({'upload_file': upload_file})
text = parser.parse_file(Path(filepath))
elif extension in ['html', 'htm']:
# Use BeautifulSoup to extract text
parser = HTMLParser()
text = parser.parse_file(Path(filepath))
elif extension == 'xlsx':
parser = XLSXParser()
text = parser.parse_file(filepath)
else:
# ['txt', 'markdown', 'md']
with open(filepath, "rb") as fp:
data = fp.read()
encoding = chardet.detect(data)['encoding']
if encoding:
text = data.decode(encoding=encoding).strip() if data else ''
else:
text = data.decode(encoding='utf-8').strip() if data else ''
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text}
......
......@@ -32,8 +32,13 @@ class VersionApi(Resource):
'current_version': args.get('current_version')
})
except Exception as error:
logging.exception("Check update error.")
raise InternalServerError()
logging.warning("Check update version error: {}.".format(str(error)))
return {
'version': args.get('current_version'),
'release_date': '',
'release_notes': '',
'can_auto_update': False
}
content = json.loads(response.content)
return {
......
......@@ -3,19 +3,11 @@ from typing import Optional
import langchain
from flask import Flask
from jieba.analyse import default_tfidf
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from llama_index import IndexStructType, QueryMode
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.index.keyword_table.stopwords import STOPWORDS
from core.prompt.prompt_template import OneLineFormatter
from core.vector_store.vector_store import VectorStore
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
class HostedOpenAICredential(BaseModel):
......@@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
formatter = OneLineFormatter()
DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
default_tfidf.stop_words = STOPWORDS
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
set_handler(DifyStdOutCallbackHandler())
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
......@@ -2,7 +2,7 @@ from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
......@@ -16,23 +16,20 @@ class AgentBuilder:
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
tool_callback_manager = CallbackManager([
for tool in tools:
tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
])
for tool in tools:
tool.callback_manager = tool_callback_manager
]
prompt = cls.build_agent_prompt_template(
tools=tools,
......@@ -54,7 +51,7 @@ class AgentBuilder:
tools=tools,
agent=agent,
memory=memory,
callback_manager=agent_callback_manager,
callbacks=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
......
......@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
......@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
......@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._agent_loops = []
self._current_loop = None
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
......@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._agent_loops = []
self._current_loop = None
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
# Final Answer
......
......@@ -3,7 +3,6 @@ import logging
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
......@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
......@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
) -> None:
"""Do nothing."""
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass
from llama_index import Response
from typing import List
from langchain.schema import Document
from extensions.ext_database import db
from models.dataset import DocumentSegment
class IndexToolCallbackHandler:
def __init__(self) -> None:
self._response = None
@property
def response(self) -> Response:
return self._response
def on_tool_end(self, response: Response) -> None:
"""Handle tool end."""
self._response = response
class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, dataset_id: str) -> None:
super().__init__()
self.dataset_id = dataset_id
def on_tool_end(self, response: Response) -> None:
def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end."""
for node in response.source_nodes:
index_node_id = node.node.doc_id
for document in documents:
doc_id = document.metadata['doc_id']
# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == index_node_id
DocumentSegment.index_node_id == doc_id
).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
......
......@@ -3,7 +3,7 @@ import time
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
......@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
conversation_message_task: ConversationMessageTask):
......@@ -25,35 +26,35 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
if 'Chat' in serialized['name']:
real_prompts = []
messages = []
for prompt in prompts:
role, content = prompt.split(': ', maxsplit=1)
if role == 'human':
for message in messages[0]:
if message.type == 'human':
role = 'user'
message = HumanMessage(content=content)
elif role == 'ai':
elif message.type == 'ai':
role = 'assistant'
message = AIMessage(content=content)
else:
message = SystemMessage(content=content)
role = 'system'
real_prompt = {
real_prompts.append({
"role": role,
"text": content
}
real_prompts.append(real_prompt)
messages.append(message)
"text": message.content
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
else:
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
......@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
pass
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
......@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
class MainChainGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
......@@ -50,8 +50,10 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
chain_type = serialized['id'][-1]
if chain_type:
self._current_chain_result = ChainResult(
type=serialized['name'],
type=chain_type,
prompt=inputs,
started_at=time.perf_counter()
)
......@@ -75,63 +77,3 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
) -> None:
logging.error(error)
self.clear_chain_results()
\ No newline at end of file
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.error(error)
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass
import os
import sys
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
class DifyStdOutCallbackHandler(BaseCallbackHandler):
......@@ -13,16 +14,22 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Initialize callback handler."""
self.color = color
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
for sub_messages in messages:
for sub_message in sub_messages:
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
if 'Chat' in serialized['name']:
for prompt in prompts:
print_text(prompt + "\n", color='blue')
else:
print_text(prompts[0] + "\n", color='blue')
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
......@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
......@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent end."""
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
......
from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
......@@ -14,7 +12,7 @@ class ChainBuilder:
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
......@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)
......
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
......@@ -52,7 +53,8 @@ class LLMRouterChain(Chain):
def _call(
self,
inputs: Dict[str, Any]
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],
......
from typing import Optional, List
from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
......@@ -18,6 +16,7 @@ from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
......@@ -30,6 +29,7 @@ class MainChainBuilder:
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
......@@ -42,9 +42,8 @@ class MainChainBuilder:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
......@@ -57,7 +56,9 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
......@@ -93,7 +94,8 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)
......
import math
from typing import Mapping, List, Dict, Any, Optional
from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain.callbacks import CallbackManager
from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
......@@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
......@@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain):
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
dataset_tools: Mapping[str, DatasetTool]
"""Map of name to candidate chains that inputs can be routed to."""
class Config:
......@@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain):
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
dataset_tools = {}
for dataset in datasets:
dataset_tool = DatasetToolBuilder.build_dataset_tool(
# fulfill description when it is empty
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
continue
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=k,
dataset=dataset,
response_mode='no_synthesizer', # "compact"
callback_handler=DatasetToolCallbackHandler(conversation_message_task)
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
if dataset_tool:
dataset_tools[dataset.id] = dataset_tool
dataset_tools[str(dataset.id)] = dataset_tool
return cls(
router_chain=router_chain,
......@@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain):
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
def _call(
self,
inputs: Dict[str, Any]
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if len(self.dataset_tools) == 0:
return {"text": ''}
......
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
......@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
return self.canned_response
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
......@@ -30,12 +31,20 @@ class ToolChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)
......
import logging
from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
......@@ -34,8 +35,6 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
memory = None
if conversation:
# get memory of conversation (read-only)
......@@ -48,6 +47,14 @@ class Completion:
inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
......@@ -64,6 +71,7 @@ class Completion:
main_chain = MainChainBuilder.to_langchain_components(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
)
......@@ -115,7 +123,7 @@ class Completion:
memory=memory
)
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
......@@ -247,16 +255,14 @@ And answer according to the language of the user's question.
return messages, ['\nHuman:']
@classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager:
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
return [llm_callback_handler, DifyStdOutCallbackHandler()]
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
......@@ -293,7 +299,8 @@ And answer according to the language of the user's question.
return memory
@classmethod
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
......@@ -302,8 +309,26 @@ And answer according to the language of the user's question.
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
raise LLMBadRequestError("Query is too long")
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=None,
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
......@@ -360,7 +385,7 @@ And answer according to the language of the user's question.
streaming=streaming
)
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=llm,
......
......@@ -293,12 +293,12 @@ class PubHandler:
if not user:
raise ValueError("user is required")
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result:{}-{}".format(user_str, task_id)
@classmethod
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str):
......@@ -306,10 +306,10 @@ class PubHandler:
'event': 'message',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
'conversation_id': str(self._conversation.id)
}
}
......
import tempfile
from pathlib import Path
from typing import List, Union
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
import logging
from typing import Optional, Dict, List
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from models.dataset import Document
logger = logging.getLogger(__name__)
class CSVLoader(LCCSVLoader):
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def _read_from_file(self, csvfile):
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
import json
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__)
class ExcelLoader(BaseLoader):
"""Load xlxs files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return [Document(page_content='\n\n'.join(data))]
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
"""Markdown parser.
import logging
import re
from typing import Optional, List, Tuple, cast
Contains parser for md files.
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class MarkdownLoader(BaseLoader):
"""Load md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from llama_index.readers.file.base_parser import BaseParser
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
class MarkdownParser(BaseParser):
"""Markdown parser.
remove_images: Whether to remove images from the text.
Extract text from markdown files.
Returns dictionary with keys as headers and values as the text between headers.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
*args: Any,
file_path: str,
remove_hyperlinks: bool = True,
remove_images: bool = True,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
"""Initialize with file path."""
self._file_path = file_path
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
value = value.strip()
if header is None:
documents.append(Document(page_content=value))
else:
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
return documents
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
......@@ -79,33 +103,32 @@ class MarkdownParser(BaseParser):
content = re.sub(pattern, r"\1", content)
return content
def _init_parser(self) -> Dict:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
with open(filepath, "r", encoding="utf-8") as f:
content = ""
try:
with open(filepath, "r", encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {filepath}") from e
except Exception as e:
raise RuntimeError(f"Error loading {filepath}") from e
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
markdown_tups = self.markdown_to_tups(content)
return markdown_tups
def parse_file(
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results
return self.markdown_to_tups(content)
import logging
from typing import List, Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents
from typing import Any, Dict, Optional, Sequence
import tiktoken
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from langchain.schema import Document
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
......@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore(BaseDocumentStore):
class DatesetDocumentStore:
def __init__(
self,
dataset: Dataset,
......@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
return self._embedding_model_name
@property
def docs(self) -> Dict[str, BaseDocument]:
def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
......@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
output = {}
for document_segment in document_segments:
doc_id = document_segment.index_node_id
result = self.segment_to_dict(document_segment)
output[doc_id] = json_to_doc(result)
output[doc_id] = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
return output
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id
......@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
max_position = 0
for doc in docs:
if doc.is_doc_id_none:
raise ValueError("doc_id not set")
if not isinstance(doc, Node):
raise ValueError("doc must be a Node")
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
raise ValueError(
f"doc_id {doc.get_doc_id()} already exists. "
f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite."
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
if not segment_document:
max_position += 1
......@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
index_node_id=doc.get_doc_id(),
index_node_hash=doc.get_doc_hash(),
index_node_id=doc.metadata['doc_id'],
index_node_hash=doc.metadata['doc_hash'],
position=max_position,
content=doc.get_text(),
word_count=len(doc.get_text()),
content=doc.page_content,
word_count=len(doc.page_content),
tokens=tokens,
created_by=self._user_id,
)
db.session.add(segment_document)
else:
segment_document.content = doc.get_text()
segment_document.index_node_hash = doc.get_doc_hash()
segment_document.word_count = len(doc.get_text())
segment_document.content = doc.page_content
segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
db.session.commit()
......@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
......@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
else:
return None
result = self.segment_to_dict(document_segment)
return json_to_doc(result)
return Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
document_segment = self.get_document_segment(doc_id)
......@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
return document_segment.index_node_hash
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
......@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
).first()
return document_segment
def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
return {
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"text": segment.content,
"__type__": Node.get_type()
}
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {}
@property
def docs(self) -> Dict[str, BaseDocument]:
return {}
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
pass
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
return False
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
return None
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
pass
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
pass
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
return None
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
import logging
from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
# use doc embedding cache or store if not exists
text_embeddings = []
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
i += 1
text_embeddings.extend(embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to db')
return embedding_results
from typing import Optional, Any, List
import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
_TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
By default, this is a wrapper around _get_text_embedding.
Can be overriden for batch queries.
"""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(self._get_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(await self._aget_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any
from langchain.schema import Document, BaseRetriever
from models.dataset import Dataset
class BaseIndex(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
@abstractmethod
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
raise NotImplementedError
@abstractmethod
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
\ No newline at end of file
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder
class IndexBuilder:
@classmethod
def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
# set number of output tokens
num_output = 512
# only for verbose
callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='text-davinci-003',
temperature=0,
max_tokens=num_output,
callback_manager=callback_manager,
)
llm_predictor = LLMPredictor(llm=llm)
# These parameters here will affect the logic of segmenting the final synthesized response.
# The number of refinement iterations in the synthesis process depends
# on whether the length of the segmented output exceeds the max_input_size.
prompt_helper = PromptHelper(
max_input_size=3500,
num_output=num_output,
max_chunk_overlap=20
)
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002'
)
return ServiceContext.from_defaults(
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
embed_model=OpenAIEmbedding(**model_credentials),
)
@classmethod
def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='fake'
)
return ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
import re
from typing import (
Any,
Dict,
List,
Set,
Optional
)
import jieba.analyse
from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
def jieba_extract_keywords(
text_chunk: str,
max_keywords: Optional[int] = None,
expand_with_subtokens: bool = True,
) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text_chunk,
topK=max_keywords,
)
if expand_with_subtokens:
return set(expand_tokens_with_subtokens(keywords))
else:
return set(keywords)
def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
"""GPT JIEBA Keyword Table Index.
This index uses a JIEBA keyword extractor to extract keywords from the text.
"""
def _extract_keywords(self, text: str) -> Set[str]:
"""Extract keywords from text."""
return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
@classmethod
def get_query_map(self) -> QueryMap:
"""Get query map."""
super_map = super().get_query_map()
super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
return super_map
def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
# get set of ids that correspond to node
node_idxs_to_delete = {doc_id}
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in self._index_struct.table.items():
if node_idxs_to_delete.intersection(node_idxs):
self._index_struct.table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not self._index_struct.table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del self._index_struct.table[keyword]
class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
"""GPT Keyword Table Index JIEBA Query.
Extracts keywords using JIEBA keyword extractor.
Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
.. code-block:: python
response = index.query("<query_str>", mode="jieba")
See BaseGPTKeywordTableQuery for arguments.
"""
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
def _get_keywords(self, query_str: str) -> List[str]:
"""Extract keywords."""
return list(
jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
)
import json
from typing import List, Optional
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
index_struct = KeywordTable()
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node in nodes:
keywords = index._extract_keywords(node.get_text())
self.update_segment_keywords(node.doc_id, list(keywords))
index._index_struct.add_node(list(keywords), node)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
def del_nodes(self, node_ids: List[str]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node_id in node_ids:
index.delete(node_id)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
@property
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
docstore = DatesetDocumentStore(
dataset=self._dataset,
user_id=self._dataset.created_by,
embedding_model_name="text-embedding-ada-002"
)
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return None
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
def get_keyword_table(self):
dataset_keyword_table = self._dataset.dataset_keyword_table
if dataset_keyword_table:
return dataset_keyword_table
return None
def update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
from core.index.keyword_table_index.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
\ No newline at end of file
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict
from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra
from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class KeywordTableIndex(BaseIndex):
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
super().__init__(dataset)
self._config = config
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
k = search_kwargs.get('k') if search_kwargs.get('k') else 4
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
return documents
def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
}
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:
return dataset_keyword_table.keyword_table_dict['__data__']['table']
else:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
list(chunk_indices_count.keys()),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
from typing import (
Any,
Dict,
Optional, Sequence,
)
from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE
class EnhanceResponseSynthesizer(ResponseSynthesizer):
@classmethod
def from_args(
cls,
service_context: ServiceContext,
streaming: bool = False,
use_async: bool = False,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_mode: ResponseMode = ResponseMode.DEFAULT,
response_kwargs: Optional[Dict] = None,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
) -> "ResponseSynthesizer":
response_builder: Optional[BaseResponseBuilder] = None
if response_mode != ResponseMode.NO_TEXT:
if response_mode == 'no_synthesizer':
response_builder = NoSynthesizer(
service_context=service_context,
simple_template=simple_template,
streaming=streaming,
)
else:
response_builder = get_response_builder(
service_context,
text_qa_template,
refine_template,
simple_template,
response_mode,
use_async=use_async,
streaming=streaming,
)
return cls(response_builder, response_mode, response_kwargs, optimizer)
class NoSynthesizer(BaseResponseBuilder):
def __init__(
self,
service_context: ServiceContext,
simple_template: Optional[SimpleInputPrompt] = None,
streaming: bool = False,
) -> None:
super().__init__(service_context, streaming)
async def aget_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
\ No newline at end of file
from pathlib import Path
from typing import Dict
from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
with open(file, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text
from pathlib import Path
from typing import Dict
from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader
from extensions.ext_storage import storage
from models.model import UploadFile
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if not current_app.config.get('PDF_PREVIEW', True):
return ''
plaintext_file_key = ''
plaintext_file_exists = False
if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
upload_file: UploadFile = self._parser_config['upload_file']
if upload_file.hash:
plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return text
except FileNotFoundError:
pass
text_list = []
with open(file, "rb") as fp:
# Create a PDF object
pdf = PdfReader(fp)
# Get the number of pages in the PDF document
num_pages = len(pdf.pages)
# Iterate over every page
for page in range(num_pages):
# Extract the text from the page
page_text = pdf.pages[page].extract_text()
text_list.append(page_text)
text = "\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return text
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app
class XLSXParser(BaseParser):
"""XLSX parser."""
def _init_parser(self) -> Dict:
"""Init parser"""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
data = []
keys = []
with open(file, "r") as fp:
wb = load_workbook(filename=file, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, row))
row_dict = {k: v for k, v in row_dict.items() if v}
data.append(json.dumps(row_dict, ensure_ascii=False))
return '\n\n'.join(data)
import json
import logging
from typing import List, Optional
from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding
class VectorIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
if duplicate_check:
nodes = self._filter_duplicate_nodes(index, nodes)
embedding_queue_nodes = []
embedded_nodes = []
for node in nodes:
node_hash = node.doc_hash
# if node hash in cached embedding tables, use cached embedding
embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
if embedding:
node.embedding = embedding.get_embedding()
embedded_nodes.append(node)
else:
embedding_queue_nodes.append(node)
if embedding_queue_nodes:
embedding_results = index._get_node_embedding_results(
embedding_queue_nodes,
set(),
)
# pre embed nodes for cached embedding
for embedding_result in embedding_results:
node = embedding_result.node
node.embedding = embedding_result.embedding
try:
embedding = Embedding(hash=node.doc_hash)
embedding.set_embedding(node.embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
embedded_nodes.append(node)
self.index_insert_nodes(index, embedded_nodes)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
index.insert_nodes(nodes)
def del_nodes(self, node_ids: List[str]):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
for node_id in node_ids:
self.index_delete_node(index, node_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
index.delete_node(node_id)
def del_doc(self, doc_id: str):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
self.index_delete_doc(index, doc_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
index.delete(doc_id)
@property
def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
if not self._dataset.index_struct_dict:
return None
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
return vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
for node in nodes:
node_id = node.doc_id
exists_duplicate_node = index.exists_by_node_id(node_id)
if exists_duplicate_node:
nodes.remove(node)
return nodes
import json
import logging
from abc import abstractmethod
from typing import List, Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
import os
from typing import Optional, Any, List, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
# original class_prefix
return True
return False
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError(f"Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
from typing import Optional, cast
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False
This diff is collapsed.
from typing import Union, Optional
from typing import Union, Optional, List
from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
......@@ -32,12 +31,11 @@ class LLMBuilder:
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
if model_name == 'fake':
return FakeListLLM(responses=[])
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == 'openai':
......@@ -52,16 +50,21 @@ class LLMBuilder:
else:
raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
model_kwargs = {
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
......@@ -69,7 +72,7 @@ class LLMBuilder:
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
......@@ -82,7 +85,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callback_manager=callback_manager
callbacks=callbacks
)
@classmethod
......
......@@ -42,6 +42,9 @@ class AzureProvider(BaseProvider):
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
......
import os
from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
......@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(prompts, stop)
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop)
return await super().agenerate(prompts, stop, callbacks, **kwargs)
This diff is collapsed.
This diff is collapsed.
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
from langchain.schema import get_buffer_string
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from core.vector_store.vector_store import VectorStore
vector_store = VectorStore()
def init_app(app):
vector_store.init_app(app)
This diff is collapsed.
......@@ -38,8 +38,6 @@ class Account(UserMixin, db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
_current_tenant: db.Model = None
@property
def current_tenant(self):
return self._current_tenant
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment