Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
91348497
Unverified
Commit
91348497
authored
Jan 03, 2024
by
Yeuoly
Committed by
GitHub
Jan 03, 2024
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix: remove tiktoken from text splitter (#1876)
parent
fcf85129
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
8 deletions
+38
-8
indexing_runner.py
api/core/indexing_runner.py
+7
-5
fixed_text_splitter.py
api/core/spiltter/fixed_text_splitter.py
+31
-3
No files found.
api/core/indexing_runner.py
View file @
91348497
...
...
@@ -5,12 +5,12 @@ import re
import
threading
import
time
import
uuid
from
typing
import
Optional
,
List
,
cast
from
typing
import
Optional
,
List
,
cast
,
Type
,
Union
,
Literal
,
AbstractSet
,
Collection
,
Any
from
flask
import
current_app
,
Flask
from
flask_login
import
current_user
from
langchain.schema
import
Document
from
langchain.text_splitter
import
RecursiveCharacterTextSplitter
,
TextSplitter
from
langchain.text_splitter
import
TextSplitter
,
TS
,
Token
TextSplitter
from
sqlalchemy.orm.exc
import
ObjectDeletedError
from
core.data_loader.file_extractor
import
FileExtractor
...
...
@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
from
core.model_runtime.entities.model_entities
import
ModelType
,
PriceType
from
core.model_runtime.model_providers.__base.large_language_model
import
LargeLanguageModel
from
core.model_runtime.model_providers.__base.text_embedding_model
import
TextEmbeddingModel
from
core.spiltter.fixed_text_splitter
import
FixedRecursiveCharacterTextSplitter
from
core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier
import
GPT2Tokenizer
from
core.spiltter.fixed_text_splitter
import
FixedRecursiveCharacterTextSplitter
,
EnhanceRecursiveCharacterTextSplitter
from
extensions.ext_database
import
db
from
extensions.ext_redis
import
redis_client
from
extensions.ext_storage
import
storage
...
...
@@ -502,7 +503,8 @@ class IndexingRunner:
if
separator
:
separator
=
separator
.
replace
(
'
\\
n'
,
'
\n
'
)
character_splitter
=
FixedRecursiveCharacterTextSplitter
.
from_tiktoken_encoder
(
character_splitter
=
FixedRecursiveCharacterTextSplitter
.
from_gpt2_encoder
(
chunk_size
=
segmentation
[
"max_tokens"
],
chunk_overlap
=
0
,
fixed_separator
=
separator
,
...
...
@@ -510,7 +512,7 @@ class IndexingRunner:
)
else
:
# Automatic segmentation
character_splitter
=
RecursiveCharacterTextSplitter
.
from_tiktoken
_encoder
(
character_splitter
=
EnhanceRecursiveCharacterTextSplitter
.
from_gpt2
_encoder
(
chunk_size
=
DatasetProcessRule
.
AUTOMATIC_RULES
[
'segmentation'
][
'max_tokens'
],
chunk_overlap
=
0
,
separators
=
[
"
\n\n
"
,
"。"
,
"."
,
" "
,
""
]
...
...
api/core/spiltter/fixed_text_splitter.py
View file @
91348497
...
...
@@ -7,10 +7,38 @@ from typing import (
Optional
,
)
from
langchain.text_splitter
import
RecursiveCharacterTextSplitter
from
langchain.text_splitter
import
RecursiveCharacterTextSplitter
,
TokenTextSplitter
,
TS
,
Type
,
Union
,
AbstractSet
,
Literal
,
Collection
from
core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier
import
GPT2Tokenizer
class
FixedRecursiveCharacterTextSplitter
(
RecursiveCharacterTextSplitter
):
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive character splitter that measures chunk length in GPT-2 tokens.

    Adds the ``from_gpt2_encoder`` alternative constructor so that splitters
    can be built without relying on tiktoken for length measurement.
    """

    @classmethod
    def from_gpt2_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        # Immutable default: this value is forwarded by reference to the
        # constructed splitter, so a shared mutable set() could leak state
        # between instances.
        allowed_special: Union[Literal["all"], AbstractSet[str]] = frozenset(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Build a splitter whose length function is the GPT-2 token count.

        Args:
            encoding_name: Forwarded to the constructor only when ``cls`` is a
                ``TokenTextSplitter`` subclass; otherwise unused.
            model_name: Forwarded under the same condition as ``encoding_name``.
            allowed_special: Forwarded under the same condition.
            disallowed_special: Forwarded under the same condition.
            **kwargs: Passed through to the ``cls`` constructor unchanged.

        Returns:
            An instance of ``cls`` constructed with ``length_function`` set to
            a callable that returns ``GPT2Tokenizer.get_num_tokens(text)``.
        """

        def _token_length(text: str) -> int:
            # Length is the number of tokens, not the number of characters.
            return GPT2Tokenizer.get_num_tokens(text)

        if issubclass(cls, TokenTextSplitter):
            # Token-based splitters take the encoder settings as constructor
            # arguments; the explicit parameters override same-named kwargs.
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_length, **kwargs)
class
FixedRecursiveCharacterTextSplitter
(
EnhanceRecursiveCharacterTextSplitter
):
def
__init__
(
self
,
fixed_separator
:
str
=
"
\n\n
"
,
separators
:
Optional
[
List
[
str
]]
=
None
,
**
kwargs
:
Any
):
"""Create a new TextSplitter."""
super
()
.
__init__
(
**
kwargs
)
...
...
@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
if
_good_splits
:
merged_text
=
self
.
_merge_splits
(
_good_splits
,
separator
)
final_chunks
.
extend
(
merged_text
)
return
final_chunks
return
final_chunks
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment