Unverified Commit 91348497 authored by Yeuoly's avatar Yeuoly Committed by GitHub

fix: remove tiktoken from text splitter (#1876)

parent fcf85129
......@@ -5,12 +5,12 @@ import re
import threading
import time
import uuid
from typing import Optional, List, cast
from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
......@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
......@@ -502,7 +503,8 @@ class IndexingRunner:
if separator:
separator = separator.replace('\\n', '\n')
character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
fixed_separator=separator,
......@@ -510,7 +512,7 @@ class IndexingRunner:
)
else:
# Automatic segmentation
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0,
separators=["\n\n", "。", ".", " ", ""]
......
......@@ -7,10 +7,38 @@ from typing import (
Optional,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
"""
@classmethod
def from_gpt2_encoder(
cls: Type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
return GPT2Tokenizer.get_num_tokens(text)
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"encoding_name": encoding_name,
"model_name": model_name,
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
......@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
return final_chunks
return final_chunks
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment