Unverified commit 91ea6fe4, authored by Jyong and committed by GitHub

Fix/langchain document schema (#2539)

Co-authored-by: jyong <jyong@dify.ai>
parent 769be131
from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource
......
......@@ -9,7 +9,6 @@ from typing import Optional, cast
from flask import Flask, current_app
from flask_login import current_user
from langchain.text_splitter import TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.docstore.dataset_docstore import DatasetDocumentStore
......@@ -24,6 +23,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
......
from langchain.schema import Document
from core.rag.models.document import Document
class ReorderRunner:
......
......@@ -5,9 +5,9 @@ from typing import Any, Optional
import requests
from flask import current_app
from flask_login import current_user
from langchain.schema import Document
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
......
......@@ -2,12 +2,11 @@
from abc import ABC, abstractmethod
from typing import Optional
from langchain.text_splitter import TextSplitter
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
......
from typing import Optional
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field
......@@ -14,3 +16,64 @@ class Document(BaseModel):
metadata: Optional[dict] = Field(default_factory=dict)
class BaseDocumentTransformer(ABC):
    """Interface for components that turn one sequence of Documents into another.

    An implementation receives a sequence of Documents and hands back a
    sequence of transformed Documents. Both a synchronous and an asynchronous
    entry point are declared abstract, so a concrete subclass has to provide
    each of them.

    Example:
        .. code-block:: python

            class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
                embeddings: Embeddings
                similarity_fn: Callable = cosine_similarity
                similarity_threshold: float = 0.95

                class Config:
                    arbitrary_types_allowed = True

                def transform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    stateful_documents = get_stateful_documents(documents)
                    embedded_documents = _get_embeddings_from_stateful_docs(
                        self.embeddings, stateful_documents
                    )
                    included_idxs = _filter_similar_embeddings(
                        embedded_documents, self.similarity_fn, self.similarity_threshold
                    )
                    return [stateful_documents[i] for i in sorted(included_idxs)]

                async def atransform_documents(
                    self, documents: Sequence[Document], **kwargs: Any
                ) -> Sequence[Document]:
                    raise NotImplementedError

    """  # noqa: E501

    @abstractmethod
    def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Apply the transformation to ``documents`` synchronously.

        Args:
            documents: The Documents to transform.
            **kwargs: Implementation-specific options; none are defined by
                this interface.

        Returns:
            The transformed Documents.
        """

    @abstractmethod
    async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
        """Apply the transformation to ``documents`` as a coroutine.

        Args:
            documents: The Documents to transform.
            **kwargs: Implementation-specific options; none are defined by
                this interface.

        Returns:
            The transformed Documents.
        """
from typing import Optional
from langchain.schema import Document
from core.model_manager import ModelInstance
from core.rag.models.document import Document
class RerankRunner:
......
......@@ -3,7 +3,10 @@ from __future__ import annotations
from typing import Any, Optional, cast
from langchain.text_splitter import (
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.splitter.text_splitter import (
TS,
AbstractSet,
Collection,
......@@ -14,10 +17,6 @@ from langchain.text_splitter import (
Union,
)
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
......
This diff is collapsed.
......@@ -13,7 +13,6 @@ import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
......@@ -24,6 +23,7 @@ from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
......
......@@ -13,7 +13,6 @@ import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from langchain.chains import RefineDocumentsChain
from langchain.chains.summarize import refine_prompts
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.base import BaseTool
from newspaper import Article
......@@ -24,6 +23,7 @@ from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.models.document import Document
FULL_TEMPLATE = """
TITLE: {title}
......
......@@ -3,9 +3,9 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
......
......@@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
......
......@@ -4,10 +4,10 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from werkzeug.exceptions import NotFound
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
......
......@@ -3,9 +3,9 @@ import time
import click
from celery import shared_task
from langchain.schema import Document
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment