Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
a76fde3d
Unverified
Commit
a76fde3d
authored
Aug 23, 2023
by
takatost
Committed by
GitHub
Aug 23, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: optimize hf inference endpoint (#975)
parent
1fc57d73
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
11 deletions
+59
-11
huggingface_hub_model.py
api/core/model_providers/models/llm/huggingface_hub_model.py
+5
-7
huggingface_hub_provider.py
...ore/model_providers/providers/huggingface_hub_provider.py
+13
-3
huggingface_endpoint_llm.py
...re/third_party/langchain/llms/huggingface_endpoint_llm.py
+39
-0
test_huggingface_hub_provider.py
...it_tests/model_providers/test_huggingface_hub_provider.py
+2
-1
No files found.
api/core/model_providers/models/llm/huggingface_hub_model.py
View file @
a76fde3d
import
decimal
from
functools
import
wraps
from
typing
import
List
,
Optional
,
Any
from
typing
import
List
,
Optional
,
Any
from
langchain
import
HuggingFaceHub
from
langchain
import
HuggingFaceHub
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.llms
import
HuggingFaceEndpoint
from
langchain.schema
import
LLMResult
from
langchain.schema
import
LLMResult
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.entity.message
import
PromptMessage
,
MessageType
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
class
HuggingfaceHubModel
(
BaseLLM
):
class
HuggingfaceHubModel
(
BaseLLM
):
...
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
...
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
def
_init_client
(
self
)
->
Any
:
def
_init_client
(
self
)
->
Any
:
provider_model_kwargs
=
self
.
_to_model_kwargs_input
(
self
.
model_rules
,
self
.
model_kwargs
)
provider_model_kwargs
=
self
.
_to_model_kwargs_input
(
self
.
model_rules
,
self
.
model_kwargs
)
if
self
.
credentials
[
'huggingfacehub_api_type'
]
==
'inference_endpoints'
:
if
self
.
credentials
[
'huggingfacehub_api_type'
]
==
'inference_endpoints'
:
client
=
HuggingFaceEndpoint
(
client
=
HuggingFaceEndpoint
LLM
(
endpoint_url
=
self
.
credentials
[
'huggingfacehub_endpoint_url'
],
endpoint_url
=
self
.
credentials
[
'huggingfacehub_endpoint_url'
],
task
=
'text2text-generation'
,
task
=
self
.
credentials
[
'task_type'
]
,
model_kwargs
=
provider_model_kwargs
,
model_kwargs
=
provider_model_kwargs
,
huggingfacehub_api_token
=
self
.
credentials
[
'huggingfacehub_api_token'
],
huggingfacehub_api_token
=
self
.
credentials
[
'huggingfacehub_api_token'
],
callbacks
=
self
.
callbacks
,
callbacks
=
self
.
callbacks
)
)
else
:
else
:
client
=
HuggingFaceHub
(
client
=
HuggingFaceHub
(
...
...
api/core/model_providers/providers/huggingface_hub_provider.py
View file @
a76fde3d
...
@@ -2,7 +2,6 @@ import json
...
@@ -2,7 +2,6 @@ import json
from
typing
import
Type
from
typing
import
Type
from
huggingface_hub
import
HfApi
from
huggingface_hub
import
HfApi
from
langchain.llms
import
HuggingFaceEndpoint
from
core.helper
import
encrypter
from
core.helper
import
encrypter
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
from
core.model_providers.models.entity.model_params
import
KwargRule
,
ModelKwargsRules
,
ModelType
...
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
...
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.providers.base
import
BaseModelProvider
,
CredentialsValidateFailedError
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.base
import
BaseProviderModel
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
from
models.provider
import
ProviderType
from
models.provider
import
ProviderType
...
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
...
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
if
'huggingfacehub_endpoint_url'
not
in
credentials
:
if
'huggingfacehub_endpoint_url'
not
in
credentials
:
raise
CredentialsValidateFailedError
(
'Hugging Face Hub Endpoint URL must be provided.'
)
raise
CredentialsValidateFailedError
(
'Hugging Face Hub Endpoint URL must be provided.'
)
if
'task_type'
not
in
credentials
:
raise
CredentialsValidateFailedError
(
'Task Type must be provided.'
)
if
credentials
[
'task_type'
]
not
in
(
"text2text-generation"
,
"text-generation"
,
"summarization"
):
raise
CredentialsValidateFailedError
(
'Task Type must be one of text2text-generation, text-generation, summarization.'
)
try
:
try
:
llm
=
HuggingFaceEndpoint
(
llm
=
HuggingFaceEndpoint
LLM
(
endpoint_url
=
credentials
[
'huggingfacehub_endpoint_url'
],
endpoint_url
=
credentials
[
'huggingfacehub_endpoint_url'
],
task
=
"text2text-generation"
,
task
=
credentials
[
'task_type'
]
,
model_kwargs
=
{
"temperature"
:
0.5
,
"max_new_tokens"
:
200
},
model_kwargs
=
{
"temperature"
:
0.5
,
"max_new_tokens"
:
200
},
huggingfacehub_api_token
=
credentials
[
'huggingfacehub_api_token'
]
huggingfacehub_api_token
=
credentials
[
'huggingfacehub_api_token'
]
)
)
...
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
...
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
}
}
credentials
=
json
.
loads
(
provider_model
.
encrypted_config
)
credentials
=
json
.
loads
(
provider_model
.
encrypted_config
)
if
'task_type'
not
in
credentials
:
credentials
[
'task_type'
]
=
'text-generation'
if
credentials
[
'huggingfacehub_api_token'
]:
if
credentials
[
'huggingfacehub_api_token'
]:
credentials
[
'huggingfacehub_api_token'
]
=
encrypter
.
decrypt_token
(
credentials
[
'huggingfacehub_api_token'
]
=
encrypter
.
decrypt_token
(
self
.
provider
.
tenant_id
,
self
.
provider
.
tenant_id
,
...
...
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py
0 → 100644
View file @
a76fde3d
from
typing
import
Dict
from
langchain.llms
import
HuggingFaceEndpoint
from
pydantic
import
Extra
,
root_validator
from
langchain.utils
import
get_from_dict_or_env
class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
    """HuggingFace Endpoint models with a configurable inference task.

    Thin subclass of ``langchain.llms.HuggingFaceEndpoint`` whose purpose is to
    relax the base class's environment validation: the API token may be supplied
    either as the ``huggingfacehub_api_token`` constructor argument or via the
    ``HUGGINGFACEHUB_API_TOKEN`` environment variable. The ``huggingface_hub``
    python package must be installed.

    The ``task`` field is passed through to the endpoint (e.g.
    ``text-generation``, ``text2text-generation``, ``summarization``).

    Example:
        .. code-block:: python

            endpoint_url = (
                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
            )
            hf = HuggingFaceEndpointLLM(
                endpoint_url=endpoint_url,
                task="text-generation",
                huggingfacehub_api_token="my-api-key"
            )
    """

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API token from the values dict or the environment.

        Unlike the base class validator, this does not probe the
        ``huggingface_hub`` client; it only ensures the token is present,
        raising ``ValueError`` (via ``get_from_dict_or_env``) when neither the
        ``huggingfacehub_api_token`` key nor the ``HUGGINGFACEHUB_API_TOKEN``
        environment variable is set.
        """
        values["huggingfacehub_api_token"] = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        return values
api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py
View file @
a76fde3d
...
@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
...
@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
=
{
INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
=
{
'huggingfacehub_api_type'
:
'inference_endpoints'
,
'huggingfacehub_api_type'
:
'inference_endpoints'
,
'huggingfacehub_api_token'
:
'valid_key'
,
'huggingfacehub_api_token'
:
'valid_key'
,
'huggingfacehub_endpoint_url'
:
'valid_url'
'huggingfacehub_endpoint_url'
:
'valid_url'
,
'task_type'
:
'text-generation'
}
}
def
encrypt_side_effect
(
tenant_id
,
encrypt_key
):
def
encrypt_side_effect
(
tenant_id
,
encrypt_key
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment