Unverified Commit c4d8bdc3 authored by takatost's avatar takatost Committed by GitHub

fix: hf hosted inference check (#1128)

parent 681eb1cf
from typing import List, Optional, Any from typing import List, Optional, Any
from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult from langchain.schema import LLMResult
...@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM ...@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
class HuggingfaceHubModel(BaseLLM): class HuggingfaceHubModel(BaseLLM):
...@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM): ...@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
streaming=streaming streaming=streaming
) )
else: else:
client = HuggingFaceHub( client = HuggingFaceHubLLM(
repo_id=self.name, repo_id=self.name,
task=self.credentials['task_type'], task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs, model_kwargs=provider_model_kwargs,
...@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM): ...@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
if 'baichuan' in self.name.lower(): if 'baichuan' in self.name.lower():
return False return False
return True return True
else:
return False
...@@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider): ...@@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
raise CredentialsValidateFailedError('Task Type must be provided.') raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"): if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.') raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization.')
try: try:
llm = HuggingFaceEndpointLLM( llm = HuggingFaceEndpointLLM(
......
from typing import Dict, Optional, List, Any
from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from pydantic import root_validator
from langchain.utils import get_from_dict_or_env
class HuggingFaceHubLLM(HuggingFaceHub):
"""HuggingFaceHub models.
To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
Example:
.. code-block:: python
from langchain.llms import HuggingFaceHub
hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
client = InferenceApi(
repo_id=values["repo_id"],
token=huggingfacehub_api_token,
task=values.get("task"),
)
client.options = {"wait_for_model": False, "use_gpu": False}
values["client"] = client
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
hfapi = HfApi(token=self.huggingfacehub_api_token)
model_info = hfapi.model_info(repo_id=self.repo_id)
if not model_info:
raise ValueError(f"Model {self.repo_id} not found.")
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {self.repo_id} is not a valid task, "
f"must be one of {VALID_TASKS}.")
return super()._call(prompt, stop, run_manager, **kwargs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment