Unverified Commit 91348497 authored by Yeuoly, committed by GitHub

fix: remove tiktoken from text splitter (#1876)

parent fcf85129
@@ -5,12 +5,12 @@ import re
 import threading
 import time
 import uuid
-from typing import Optional, List, cast
+from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
 from flask import current_app, Flask
 from flask_login import current_user
 from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
+from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError
 from core.data_loader.file_extractor import FileExtractor
@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -502,7 +503,8 @@ class IndexingRunner:
             if separator:
                 separator = separator.replace('\\n', '\n')
-            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
                 fixed_separator=separator,
@@ -510,7 +512,7 @@
             )
         else:
             # Automatic segmentation
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=0,
                 separators=["\n\n", "。", ".", " ", ""]
...
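Both hunks above replace langchain's from_tiktoken_encoder factories with a from_gpt2_encoder classmethod whose length function delegates to GPT2Tokenizer.get_num_tokens (the classmethod itself is added in the second file below). The gpt2_tokenzier module is not part of this diff; purely as a sketch, a GPT2-backed token counter could look like the following, assuming the Hugging Face transformers package. The class layout and caching here are assumptions, not the repo's actual implementation.

# Hypothetical sketch only -- the real gpt2_tokenzier module may differ.
from transformers import GPT2TokenizerFast

class GPT2Tokenizer:
    _tokenizer: GPT2TokenizerFast = None

    @classmethod
    def get_num_tokens(cls, text: str) -> int:
        # Load and cache the tokenizer once; later calls just run the
        # local BPE encoder, with no dependency on tiktoken.
        if cls._tokenizer is None:
            cls._tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        return len(cls._tokenizer.encode(text))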
@@ -7,10 +7,38 @@ from typing import (
     Optional,
 )
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 
 
-class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+    """
+    This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
+    """
+
+    @classmethod
+    def from_gpt2_encoder(
+            cls: Type[TS],
+            encoding_name: str = "gpt2",
+            model_name: Optional[str] = None,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+            **kwargs: Any,
+    ):
+        def _token_encoder(text: str) -> int:
+            return GPT2Tokenizer.get_num_tokens(text)
+
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
+        return cls(length_function=_token_encoder, **kwargs)
+
+
+class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
     def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         if _good_splits:
             merged_text = self._merge_splits(_good_splits, separator)
             final_chunks.extend(merged_text)
-        return final_chunks
\ No newline at end of file
+        return final_chunks
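For context on the design: from_gpt2_encoder only injects _token_encoder as the splitter's length_function, so chunk_size is still enforced in tokens and all of TextSplitter's merging logic is reused unchanged; only the tokenizer backing the count changes. A hypothetical call mirroring the IndexingRunner usage above (the chunk_size and fixed_separator values here are illustrative, not taken from this diff):

# Illustrative usage; parameter values are made up.
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter

splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
    chunk_size=500,      # maximum GPT-2 tokens per chunk
    chunk_overlap=0,
    fixed_separator="\n\n",                   # tried first, verbatim
    separators=["\n\n", "。", ".", " ", ""],  # recursive fallbacks
)
chunks = splitter.split_text("Some long document text ...")

Since FixedRecursiveCharacterTextSplitter is not a TokenTextSplitter subclass, the issubclass branch is skipped and the tiktoken-specific keyword arguments (encoding_name, allowed_special, and so on) are never forwarded to the constructor.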