Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
a76fde3d
Unverified
Commit
a76fde3d
authored
Aug 23, 2023
by
takatost
Committed by
GitHub
Aug 23, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: optimize hf inference endpoint (#975)
parent
1fc57d73
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
11 deletions
+59
-11
huggingface_hub_model.py
api/core/model_providers/models/llm/huggingface_hub_model.py
+5
-7
huggingface_hub_provider.py
...ore/model_providers/providers/huggingface_hub_provider.py
+13
-3
huggingface_endpoint_llm.py
...re/third_party/langchain/llms/huggingface_endpoint_llm.py
+39
-0
test_huggingface_hub_provider.py
...it_tests/model_providers/test_huggingface_hub_provider.py
+2
-1
No files found.
api/core/model_providers/models/llm/huggingface_hub_model.py
View file @
a76fde3d
import
decimal
from
functools
import
wraps
from
typing
import
List
,
Optional
,
Any
from
langchain
import
HuggingFaceHub
from
langchain.callbacks.manager
import
Callbacks
from
langchain.llms
import
HuggingFaceEndpoint
from
langchain.schema
import
LLMResult
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.entity.message
import
PromptMessage
,
MessageType
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
class
HuggingfaceHubModel
(
BaseLLM
):
...
...
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
def
_init_client
(
self
)
->
Any
:
provider_model_kwargs
=
self
.
_to_model_kwargs_input
(
self
.
model_rules
,
self
.
model_kwargs
)
if
self
.
credentials
[
'huggingfacehub_api_type'
]
==
'inference_endpoints'
:
client
=
HuggingFaceEndpoint
(
client
=
HuggingFaceEndpoint
LLM
(
endpoint_url
=
self
.
credentials
[
'huggingfacehub_endpoint_url'
],
task
=
'text2text-generation'
,
task
=
self
.
credentials
[
'task_type'
]
,
model_kwargs
=
provider_model_kwargs
,
huggingfacehub_api_token
=
self
.
credentials
[
'huggingfacehub_api_token'
],
callbacks
=
self
.
callbacks
,
callbacks
=
self
.
callbacks
)
else
:
client
=
HuggingFaceHub
(
...
...
api/core/model_providers/providers/huggingface_hub_provider.py
View file @
a76fde3d
...
...
@@ -2,7 +2,6 @@ import json
from
typing
import
Type
from
huggingface_hub
import
HfApi
from
langchain.llms
import
HuggingFaceEndpoint
from
core.helper
import
encrypter
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
...
...
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.models.base
import
BaseProviderModel
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
from
models.provider
import
ProviderType
...
...
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
if
'huggingfacehub_endpoint_url'
not
in
credentials
:
raise
CredentialsValidateFailedError
(
'Hugging Face Hub Endpoint URL must be provided.'
)
if
'task_type'
not
in
credentials
:
raise
CredentialsValidateFailedError
(
'Task Type must be provided.'
)
if
credentials
[
'task_type'
]
not
in
(
"text2text-generation"
,
"text-generation"
,
"summarization"
):
raise
CredentialsValidateFailedError
(
'Task Type must be one of text2text-generation, text-generation, summarization.'
)
try
:
llm
=
HuggingFaceEndpoint
(
llm
=
HuggingFaceEndpoint
LLM
(
endpoint_url
=
credentials
[
'huggingfacehub_endpoint_url'
],
task
=
"text2text-generation"
,
task
=
credentials
[
'task_type'
]
,
model_kwargs
=
{
"temperature"
:
0.5
,
"max_new_tokens"
:
200
},
huggingfacehub_api_token
=
credentials
[
'huggingfacehub_api_token'
]
)
...
...
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
}
credentials
=
json
.
loads
(
provider_model
.
encrypted_config
)
if
'task_type'
not
in
credentials
:
credentials
[
'task_type'
]
=
'text-generation'
if
credentials
[
'huggingfacehub_api_token'
]:
credentials
[
'huggingfacehub_api_token'
]
=
encrypter
.
decrypt_token
(
self
.
provider
.
tenant_id
,
...
...
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
0 → 100644
View file @
a76fde3d
from
typing
import
Dict
from
langchain.llms
import
HuggingFaceEndpoint
from
pydantic
import
Extra
,
root_validator
from
langchain.utils
import
get_from_dict_or_env
class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
    """LLM backed by a dedicated HuggingFace Inference Endpoint.

    Requires the ``huggingface_hub`` python package. The API token is taken
    either from the ``HUGGINGFACEHUB_API_TOKEN`` environment variable or from
    the ``huggingfacehub_api_token`` constructor argument.

    Only supports `text-generation` and `text2text-generation` for now.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceEndpoint
            endpoint_url = (
                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
            )
            hf = HuggingFaceEndpoint(
                endpoint_url=endpoint_url,
                huggingfacehub_api_token="my-api-key"
            )
    """

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        # Resolve the token from the supplied values, falling back to the
        # environment variable, and normalise it back into ``values``.
        token = get_from_dict_or_env(
            values,
            "huggingfacehub_api_token",
            "HUGGINGFACEHUB_API_TOKEN",
        )
        values["huggingfacehub_api_token"] = token
        return values
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py
View file @
a76fde3d
...
...
@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
=
{
'huggingfacehub_api_type'
:
'inference_endpoints'
,
'huggingfacehub_api_token'
:
'valid_key'
,
'huggingfacehub_endpoint_url'
:
'valid_url'
'huggingfacehub_endpoint_url'
:
'valid_url'
,
'task_type'
:
'text-generation'
}
def
encrypt_side_effect
(
tenant_id
,
encrypt_key
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment