Project: ai-tech / dify

Commit 0796791d (unverified), authored Aug 26, 2023 by takatost, committed via GitHub on Aug 26, 2023

feat: hf inference endpoint stream support (#1028)

parent 6c148b22

Showing 14 changed files with 128 additions and 40 deletions (+128 / -40)
Changed files:

  api/core/model_providers/models/llm/anthropic_model.py                  +2   -2
  api/core/model_providers/models/llm/azure_openai_model.py               +3   -3
  api/core/model_providers/models/llm/base.py                             +4   -4
  api/core/model_providers/models/llm/chatglm_model.py                    +0   -4
  api/core/model_providers/models/llm/huggingface_hub_model.py            +13  -4
  api/core/model_providers/models/llm/openai_model.py                     +2   -2
  api/core/model_providers/models/llm/openllm_model.py                    +0   -4
  api/core/model_providers/models/llm/replicate_model.py                  +3   -3
  api/core/model_providers/models/llm/spark_model.py                      +3   -3
  api/core/model_providers/models/llm/tongyi_model.py                     +2   -2
  api/core/model_providers/models/llm/wenxin_model.py                     +0   -4
  api/core/model_providers/models/llm/xinference_model.py                 +2   -2
  api/core/third_party/langchain/llms/huggingface_endpoint_llm.py         +91  -2
  api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py   +3   -1
api/core/model_providers/models/llm/anthropic_model.py

@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
api/core/model_providers/models/llm/azure_openai_model.py

@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
\ No newline at end of file
+    @property
+    def support_streaming(self):
+        return True
api/core/model_providers/models/llm/base.py

@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
             result = self._run(
                 messages=messages,
                 stop=stop,
-                callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                 **kwargs
             )
         except Exception as ex:

@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
         else:
             completion_content = result.generations[0][0].text
 
-        if self.streaming and not self.support_streaming():
+        if self.streaming and not self.support_streaming:
             # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
             prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
             fake_llm = FakeLLM(

@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
         else:
             self.client.callbacks.extend(callbacks)
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return False
 
     def get_prompt(self, mode: str,
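The base.py change is what the per-provider edits below hinge on: support_streaming is now read as a plain attribute (self.support_streaming, no call), and because it is a property it can consult instance state such as credentials or the model name. A minimal, self-contained sketch of the pattern; EchoModel and the word-by-word replay are illustrative stand-ins for dify's real subclasses and its FakeLLM fallback, not the actual code:

class BaseLLM:
    def __init__(self, streaming: bool = False):
        self.streaming = streaming

    @property
    def support_streaming(self) -> bool:
        # Default: no native streaming. Subclasses override this property and,
        # unlike a classmethod, may inspect instance state (credentials, name).
        return False

    def _generate(self, prompt: str) -> str:
        raise NotImplementedError

    def run(self, prompt: str) -> str:
        text = self._generate(prompt)
        if self.streaming and not self.support_streaming:
            # Caller asked for streaming but the backend cannot stream:
            # replay the finished text piece by piece (dify uses FakeLLM here).
            for word in text.split():
                print(word, end=" ", flush=True)
            print()
        return text


class EchoModel(BaseLLM):
    @property
    def support_streaming(self) -> bool:
        return True  # this backend streams natively

    def _generate(self, prompt: str) -> str:
        return f"echo: {prompt}"


if __name__ == "__main__":
    print(EchoModel(streaming=True).run("hello"))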
api/core/model_providers/models/llm/chatglm_model.py

@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
             return LLMBadRequestError(f"ChatGLM: {str(ex)}")
         else:
             return ex
-
-    @classmethod
-    def support_streaming(cls):
-        return False
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            streaming = self.streaming
+            if 'baichuan' in self.name.lower():
+                streaming = False
+
             client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks
+                callbacks=self.callbacks,
+                streaming=streaming
             )
         else:
             client = HuggingFaceHub(

@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
-        return False
+    @property
+    def support_streaming(self):
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            if 'baichuan' in self.name.lower():
+                return False
+
+            return True
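For reference, a sketch of how this endpoint client ends up being constructed with the new streaming flag. Credential values, the model name, and the model_kwargs are placeholders; HuggingFaceEndpointLLM is the wrapper added under api/core/third_party/langchain/llms/huggingface_endpoint_llm.py later in this commit, so this snippet assumes the dify repo is on the import path:

from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM

# Placeholder credentials in the shape the provider expects.
credentials = {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_endpoint_url': 'https://example.endpoints.huggingface.cloud',
    'task_type': 'text-generation',
    'huggingfacehub_api_token': 'hf_xxx',
}
model_name = 'my-llama-endpoint'

# Streaming is requested by the caller but force-disabled for baichuan-family models.
streaming = True
if 'baichuan' in model_name.lower():
    streaming = False

client = HuggingFaceEndpointLLM(
    endpoint_url=credentials['huggingfacehub_endpoint_url'],
    task=credentials['task_type'],
    model_kwargs={'temperature': 0.7, 'max_new_tokens': 200},
    huggingfacehub_api_token=credentials['huggingfacehub_api_token'],
    callbacks=None,
    streaming=streaming,
)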
api/core/model_providers/models/llm/openai_model.py

@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 
     # def is_model_valid_or_raise(self):
api/core/model_providers/models/llm/openllm_model.py

@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"OpenLLM: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
api/core/model_providers/models/llm/replicate_model.py

@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
\ No newline at end of file
+    @property
+    def support_streaming(self):
+        return True
api/core/model_providers/models/llm/spark_model.py

@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
\ No newline at end of file
+    @property
+    def support_streaming(self):
+        return True
api/core/model_providers/models/llm/tongyi_model.py

@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
api/core/model_providers/models/llm/wenxin_model.py

@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
api/core/model_providers/models/llm/xinference_model.py

@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Xinference: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py

-from typing import Dict
+from typing import Dict, Any, Optional, List, Iterable, Iterator
 
+from huggingface_hub import InferenceClient
+from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.embeddings.huggingface_hub import VALID_TASKS
 from langchain.llms import HuggingFaceEndpoint
-from pydantic import Extra, root_validator
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import root_validator
 from langchain.utils import get_from_dict_or_env

@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
                 huggingfacehub_api_token="my-api-key"
             )
     """
 
+    client: Any
+    streaming: bool = False
+
     @root_validator(allow_reuse=True)
     def validate_environment(cls, values: Dict) -> Dict:

@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
             values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
         )
 
+        values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
+
         values["huggingfacehub_api_token"] = huggingfacehub_api_token
         return values
 
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to HuggingFace Hub's inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = hf("Tell me a joke.")
+        """
+        _model_kwargs = self.model_kwargs or {}
+
+        # payload samples
+        params = {**_model_kwargs, **kwargs}
+
+        # generation parameter
+        gen_kwargs = {**params, 'stop_sequences': stop}
+
+        response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
+
+        if self.streaming and isinstance(response, Iterable):
+            combined_text_output = ""
+            for token in self._stream_response(response, run_manager):
+                combined_text_output += token
+            completion = combined_text_output
+        else:
+            completion = response.generated_text
+
+        if self.task == "text-generation":
+            text = completion
+            # Remove prompt if included in generated text.
+            if text.startswith(prompt):
+                text = text[len(prompt):]
+        elif self.task == "text2text-generation":
+            text = completion
+        else:
+            raise ValueError(
+                f"Got invalid task {self.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        if stop is not None:
+            # This is a bit hacky, but I can't figure out a better way to enforce
+            # stop tokens when making calls to huggingface_hub.
+            text = enforce_stop_tokens(text, stop)
+
+        return text
+
+    def _stream_response(
+        self,
+        response: Iterable,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> Iterator[str]:
+        for r in response:
+            # skip special tokens
+            if r.token.special:
+                continue
+
+            token = r.token.text
+
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=token,
+                    verbose=self.verbose,
+                    log_probs=None,
+                )
+
+            # yield the generated token
+            yield token
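The wrapper above drives huggingface_hub's InferenceClient.text_generation directly: with stream=True and details=True the call yields stream responses whose r.token.text carries the generated text and r.token.special flags special tokens. A standalone sketch of that loop outside LangChain, with the endpoint URL, token, and generation parameters as placeholders:

from huggingface_hub import InferenceClient

# Placeholder endpoint and token; any text-generation-inference backed endpoint behaves the same way.
client = InferenceClient("https://example.endpoints.huggingface.cloud", token="hf_xxx")

# With stream=True and details=True, text_generation returns an iterator of
# stream responses instead of a single string.
stream = client.text_generation(
    "Tell me a joke.",
    stream=True,
    details=True,
    max_new_tokens=64,
    stop_sequences=["\n\n"],
)

completion = ""
for r in stream:
    if r.token.special:
        continue  # skip special tokens, as _stream_response does above
    completion += r.token.text

print(completion)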
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py

@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
 
 def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
     mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
-    mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
 
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
         model_name='test_model_name',

@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
         credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
     )
 
 def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
     mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
+    mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
+
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(