Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
c4d8bdc3
Unverified
Commit
c4d8bdc3
authored
Sep 08, 2023
by
takatost
Committed by
GitHub
Sep 08, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix: hf hosted inference check (#1128)
parent
681eb1cf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
4 deletions
+69
-4
huggingface_hub_model.py
api/core/model_providers/models/llm/huggingface_hub_model.py
+5
-3
huggingface_hub_provider.py
...ore/model_providers/providers/huggingface_hub_provider.py
+2
-1
huggingface_hub_llm.py
api/core/third_party/langchain/llms/huggingface_hub_llm.py
+62
-0
No files found.
api/core/model_providers/models/llm/huggingface_hub_model.py
View file @
c4d8bdc3
from
typing
import
List
,
Optional
,
Any
from
typing
import
List
,
Optional
,
Any
from
langchain
import
HuggingFaceHub
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.schema
import
LLMResult
from
langchain.schema
import
LLMResult
...
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
...
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
from
core.third_party.langchain.llms.huggingface_hub_llm
import
HuggingFaceHubLLM
class
HuggingfaceHubModel
(
BaseLLM
):
class
HuggingfaceHubModel
(
BaseLLM
):
...
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
...
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
streaming
=
streaming
streaming
=
streaming
)
)
else
:
else
:
client
=
HuggingFaceHub
(
client
=
HuggingFaceHub
LLM
(
repo_id
=
self
.
name
,
repo_id
=
self
.
name
,
task
=
self
.
credentials
[
'task_type'
],
task
=
self
.
credentials
[
'task_type'
],
model_kwargs
=
provider_model_kwargs
,
model_kwargs
=
provider_model_kwargs
,
...
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
...
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
if
'baichuan'
in
self
.
name
.
lower
():
if
'baichuan'
in
self
.
name
.
lower
():
return
False
return
False
return
True
return
True
else
:
return
False
api/core/model_providers/providers/huggingface_hub_provider.py
View file @
c4d8bdc3
...
@@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
...
@@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
raise
CredentialsValidateFailedError
(
'Task Type must be provided.'
)
raise
CredentialsValidateFailedError
(
'Task Type must be provided.'
)
if
credentials
[
'task_type'
]
not
in
(
"text2text-generation"
,
"text-generation"
,
"summarization"
):
if
credentials
[
'task_type'
]
not
in
(
"text2text-generation"
,
"text-generation"
,
"summarization"
):
raise
CredentialsValidateFailedError
(
'Task Type must be one of text2text-generation, text-generation, summarization.'
)
raise
CredentialsValidateFailedError
(
'Task Type must be one of text2text-generation, '
'text-generation, summarization.'
)
try
:
try
:
llm
=
HuggingFaceEndpointLLM
(
llm
=
HuggingFaceEndpointLLM
(
...
...
api/core/third_party/langchain/llms/huggingface_hub_llm.py
0 → 100644
View file @
c4d8bdc3
from
typing
import
Dict
,
Optional
,
List
,
Any
from
huggingface_hub
import
HfApi
,
InferenceApi
from
langchain
import
HuggingFaceHub
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
from
langchain.llms.huggingface_hub
import
VALID_TASKS
from
pydantic
import
root_validator
from
langchain.utils
import
get_from_dict_or_env
class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFaceHub models, with a pre-flight check for hosted inference.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceHub
            hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        # Fail fast rather than blocking until the model is loaded, and do not
        # request GPU-backed inference for this client.
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Verify the hosted Inference API is usable for ``self.repo_id``, then delegate.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional stop sequences forwarded to the parent implementation.
            run_manager: Optional callback manager forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            The generated text from ``HuggingFaceHub._call``.

        Raises:
            ValueError: If the model cannot be found, its model card disables
                the Inference API, or its pipeline tag is not one of the
                supported tasks in ``VALID_TASKS``.
        """
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        # NOTE(review): HfApi.model_info typically raises on unknown repos
        # rather than returning a falsy value; the explicit check below is a
        # defensive fallback.
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        # Fix: cardData is None for repos without card metadata; testing
        # membership on None raises "TypeError: argument of type 'NoneType'
        # is not iterable". Guard before the membership test.
        card_data = model_info.cardData
        if card_data and 'inference' in card_data and not card_data['inference']:
            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(f"Model {self.repo_id} is not a valid task, "
                             f"must be one of {VALID_TASKS}.")

        return super()._call(prompt, stop, run_manager, **kwargs)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment