ai-tech / dify · Commits · 2d9616c2

Unverified commit 2d9616c2
Authored Aug 25, 2023 by Uranus, committed by GitHub on Aug 25, 2023
fix: xinference last token being ignored (#1013)
parent 915e2652

Showing 1 changed file with 55 additions and 37 deletions:

api/core/third_party/langchain/llms/xinference_llm.py  (+55 / -37)
@@ -3,17 +3,20 @@ from typing import Optional, List, Any, Union, Generator
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import Xinference
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import RESTfulChatglmCppChatModelHandle, \
-    RESTfulChatModelHandle, RESTfulGenerateModelHandle
+from xinference.client import (
+    RESTfulChatglmCppChatModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulGenerateModelHandle,
+)
 
 
 class XinferenceLLM(Xinference):
     def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.

@@ -29,7 +32,9 @@ class XinferenceLLM(Xinference):
         model = self.client.get_model(self.model_uid)
 
         if isinstance(model, RESTfulChatModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop

@@ -37,10 +42,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output

@@ -48,7 +53,9 @@ class XinferenceLLM(Xinference):
                 completion = model.chat(prompt=prompt, generate_config=generate_config)
                 return completion["choices"][0]["message"]["content"]
         elif isinstance(model, RESTfulGenerateModelHandle):
-            generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "LlamaCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if stop:
                 generate_config["stop"] = stop

@@ -56,27 +63,31 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
             else:
-                completion = model.generate(prompt=prompt, generate_config=generate_config)
+                completion = model.generate(
+                    prompt=prompt, generate_config=generate_config
+                )
                 return completion["choices"][0]["text"]
         elif isinstance(model, RESTfulChatglmCppChatModelHandle):
-            generate_config: "ChatglmCppGenerateConfig" = kwargs.get("generate_config", {})
+            generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
+                "generate_config", {}
+            )
 
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                        model=model,
-                        prompt=prompt,
-                        run_manager=run_manager,
-                        generate_config=generate_config,
+                    model=model,
+                    prompt=prompt,
+                    run_manager=run_manager,
+                    generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output

@@ -90,12 +101,21 @@ class XinferenceLLM(Xinference):
         return completion
 
     def _stream_generate(
-            self,
-            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-            prompt: str,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+        self,
+        model: Union[
+            "RESTfulGenerateModelHandle",
+            "RESTfulChatModelHandle",
+            "RESTfulChatglmCppChatModelHandle",
+        ],
+        prompt: str,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        generate_config: Optional[
+            Union[
+                "LlamaCppGenerateConfig",
+                "PytorchGenerateConfig",
+                "ChatglmCppGenerateConfig",
+            ]
+        ] = None,
     ) -> Generator[str, None, None]:
         """
         Args:

@@ -108,7 +128,9 @@ class XinferenceLLM(Xinference):
         Yields:
             A string token.
         """
-        if isinstance(model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
+        if isinstance(
+            model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+        ):
             streaming_response = model.chat(prompt=prompt, generate_config=generate_config)

@@ -123,14 +145,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
-                        if 'finish_reason' in choice and choice['finish_reason'] \
-                                and choice['finish_reason'] in ['stop', 'length']:
-                            break
-
-                        if 'text' in choice:
+                        if "text" in choice:
                             token = choice.get("text", "")
-                        elif 'delta' in choice and 'content' in choice['delta']:
-                            token = choice.get('delta').get('content')
+                        elif "delta" in choice and "content" in choice["delta"]:
+                            token = choice.get("delta").get("content")
                         else:
                             continue
 
                         log_probs = choice.get("logprobs")
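For context on what the last hunk changes: before this commit, the streaming loop in _stream_generate broke out as soon as a choice carried a finish_reason of "stop" or "length", before reading that choice's text, so the final piece of streamed output was silently dropped. The commit removes the early break and simply yields whatever text each chunk carries, letting the stream end on its own. The snippet below is a minimal, self-contained sketch of that difference; the chunk dictionaries are hypothetical stand-ins for xinference streaming responses, not actual API output, and stream_old / stream_fixed are illustrative helpers, not dify code.

# Minimal sketch (assumed chunk shapes) of why breaking on finish_reason
# before reading the text drops the last token of a streamed completion.

chunks = [
    {"choices": [{"text": "Hello"}]},
    {"choices": [{"text": " world"}]},
    # assumed: the final streaming chunk can carry both text and a finish_reason
    {"choices": [{"text": "!", "finish_reason": "stop"}]},
]


def stream_old(chunks):
    """Pre-fix behaviour: break on finish_reason before yielding the chunk's text."""
    for chunk in chunks:
        choice = chunk["choices"][0]
        if choice.get("finish_reason") in ("stop", "length"):
            break  # the "!" carried by this chunk is never yielded
        yield choice.get("text", "")


def stream_fixed(chunks):
    """Post-fix behaviour: always yield the chunk's text; the loop ends when the stream does."""
    for chunk in chunks:
        choice = chunk["choices"][0]
        if "text" in choice:
            yield choice.get("text", "")


print("".join(stream_old(chunks)))    # "Hello world"  (last token lost)
print("".join(stream_fixed(chunks)))  # "Hello world!"

Run as-is, the old loop prints "Hello world" while the fixed loop prints "Hello world!", which is the behaviour described by the commit title.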