Unverified commit 91ea6fe4, authored by Jyong, committed by GitHub

Fix/langchain document schema (#2539)

Co-authored-by: jyong <jyong@dify.ai>
parent 769be131
from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource
...
@@ -9,7 +9,6 @@ from typing import Optional, cast
from flask import Flask, current_app
from flask_login import current_user
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.docstore.dataset_docstore import DatasetDocumentStore
@@ -24,6 +23,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
...
from core.rag.models.document import Document
from langchain.schema import Document
class ReorderRunner:
...
@@ -5,9 +5,9 @@ from typing import Any, Optional
import requests
from flask import current_app
from flask_login import current_user
from langchain.schema import Document
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
...
@@ -2,12 +2,11 @@
from abc import ABC, abstractmethod
from typing import Optional
from langchain.text_splitter import TextSplitter
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
...
from typing import Optional
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field
@@ -14,3 +16,64 @@ class Document(BaseModel):
metadata: Optional[dict] = Field(default_factory=dict)
class BaseDocumentTransformer(ABC):
    """Abstract interface for document transformation systems.

    A document transformer consumes a sequence of ``Document`` objects and
    produces a sequence of transformed ``Document`` objects.  Concrete
    subclasses must provide both a synchronous and an asynchronous
    transformation entry point.

    Example:
        .. code-block:: python

            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
                embeddings: Embeddings
                similarity_fn: Callable = cosine_similarity
                similarity_threshold: float = 0.95

                class Config:
                    arbitrary_types_allowed = True

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    stateful_documents = get_stateful_documents(documents)
                    embedded_documents = _get_embeddings_from_stateful_docs(
                        self.embeddings, stateful_documents
                    )
                    included_idxs = _filter_similar_embeddings(
                        embedded_documents, self.similarity_fn, self.similarity_threshold
                    )
                    return [stateful_documents[i] for i in sorted(included_idxs)]

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError
    """  # noqa: E501

    @abstractmethod
    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """

    @abstractmethod
    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Asynchronously transform a list of documents.

        Args:
            documents: A sequence of Documents to be transformed.

        Returns:
            A list of transformed Documents.
        """
from typing import Optional
from langchain.schema import Document
from core.model_manager import ModelInstance
from core.rag.models.document import Document
class RerankRunner:
...
@@ -3,7 +3,10 @@ from __future__ import annotations
from typing import Any, Optional, cast
from langchain.text_splitter import (
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.splitter.text_splitter import (
    TS,
    AbstractSet,
    Collection,
@@ -14,10 +17,6 @@ from langchain.text_splitter import (
    Union,
)
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """
...
This diff is collapsed.
@@ -13,7 +13,6 @@ import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
@@ -24,6 +23,7 @@ from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
...
@@ -13,7 +13,6 @@ import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
@@ -24,6 +23,7 @@ from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
...
@@ -3,9 +3,9 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
...
@@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
...
@@ -4,10 +4,10 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
...
@@ -3,9 +3,9 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
...
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.