Unverified Commit 0796791d authored by takatost's avatar takatost Committed by GitHub

feat: hf inference endpoint stream support (#1028)

parent 6c148b22
...@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM): ...@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
else: else:
return ex return ex
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
...@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM): ...@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
else: else:
return ex return ex
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
\ No newline at end of file
...@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel): ...@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
result = self._run( result = self._run(
messages=messages, messages=messages,
stop=stop, stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None, callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
**kwargs **kwargs
) )
except Exception as ex: except Exception as ex:
...@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel): ...@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
else: else:
completion_content = result.generations[0][0].text completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming(): if self.streaming and not self.support_streaming:
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM( fake_llm = FakeLLM(
...@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel): ...@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
else: else:
self.client.callbacks.extend(callbacks) self.client.callbacks.extend(callbacks)
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return False return False
def get_prompt(self, mode: str, def get_prompt(self, mode: str,
......
...@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM): ...@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
return LLMBadRequestError(f"ChatGLM: {str(ex)}") return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else: else:
return ex return ex
@classmethod
def support_streaming(cls):
return False
...@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM): ...@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints': if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
streaming = self.streaming
if 'baichuan' in self.name.lower():
streaming = False
client = HuggingFaceEndpointLLM( client = HuggingFaceEndpointLLM(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'], endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task=self.credentials['task_type'], task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs, model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'], huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks callbacks=self.callbacks,
streaming=streaming
) )
else: else:
client = HuggingFaceHub( client = HuggingFaceHub(
...@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM): ...@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}") return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return False if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'baichuan' in self.name.lower():
return False
return True
...@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM): ...@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
else: else:
return ex return ex
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
# def is_model_valid_or_raise(self): # def is_model_valid_or_raise(self):
......
...@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM): ...@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}") return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@classmethod
def support_streaming(cls):
return False
...@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM): ...@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
else: else:
return ex return ex
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
\ No newline at end of file
...@@ -65,6 +65,6 @@ class SparkModel(BaseLLM): ...@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
else: else:
return ex return ex
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
\ No newline at end of file
...@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM): ...@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
else: else:
return ex return ex
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
...@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM): ...@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}") return LLMBadRequestError(f"Wenxin: {str(ex)}")
@classmethod
def support_streaming(cls):
return False
...@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM): ...@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}") return LLMBadRequestError(f"Xinference: {str(ex)}")
@classmethod @property
def support_streaming(cls): def support_streaming(self):
return True return True
from typing import Dict from typing import Dict, Any, Optional, List, Iterable, Iterator
from huggingface_hub import InferenceClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.embeddings.huggingface_hub import VALID_TASKS
from langchain.llms import HuggingFaceEndpoint from langchain.llms import HuggingFaceEndpoint
from pydantic import Extra, root_validator from langchain.llms.utils import enforce_stop_tokens
from pydantic import root_validator
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
...@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint): ...@@ -27,6 +31,8 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
huggingfacehub_api_token="my-api-key" huggingfacehub_api_token="my-api-key"
) )
""" """
client: Any
streaming: bool = False
@root_validator(allow_reuse=True) @root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
...@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint): ...@@ -35,5 +41,88 @@ class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
) )
values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)
values["huggingfacehub_api_token"] = huggingfacehub_api_token values["huggingfacehub_api_token"] = huggingfacehub_api_token
return values return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
# payload samples
params = {**_model_kwargs, **kwargs}
# generation parameter
gen_kwargs = {
**params,
'stop_sequences': stop
}
response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)
if self.streaming and isinstance(response, Iterable):
combined_text_output = ""
for token in self._stream_response(response, run_manager):
combined_text_output += token
completion = combined_text_output
else:
completion = response.generated_text
if self.task == "text-generation":
text = completion
# Remove prompt if included in generated text.
if text.startswith(prompt):
text = text[len(prompt) :]
elif self.task == "text2text-generation":
text = completion
else:
raise ValueError(
f"Got invalid task {self.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
def _stream_response(
self,
response: Iterable,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Iterator[str]:
for r in response:
# skip special tokens
if r.token.special:
continue
token = r.token.text
if run_manager:
run_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=None
)
# yield the generated token
yield token
...@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i ...@@ -63,7 +63,7 @@ def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_i
def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker): def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None) mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc") mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name', model_name='test_model_name',
...@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker): ...@@ -71,8 +71,10 @@ def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker):
credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL
) )
def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker): def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None) mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None)
mocker.patch('core.third_party.langchain.llms.huggingface_endpoint_llm.HuggingFaceEndpointLLM._call', return_value="abc")
with pytest.raises(CredentialsValidateFailedError): with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment