Unverified commit 6c832ee3, authored by takatost, committed by GitHub

fix: remove openllm pypi package because this package is too large (#931)

parent 25264e78
 from typing import List, Optional, Any
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
 from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM

 class OpenLLMModel(BaseLLM):
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
         client = OpenLLM(
             server_url=self.credentials.get('server_url'),
             callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
         )

         return client
...
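Note: with this change the provider's model kwargs are no longer unpacked into the constructor; they travel as a single llm_kwargs dict that the custom wrapper forwards to the server as the generation config. A minimal standalone sketch of the resulting constructor call, where the credential and kwarg values are illustrative placeholders rather than the real provider data:

from core.third_party.langchain.llms.openllm import OpenLLM

# Placeholder values standing in for self.credentials / self.provider_model_kwargs.
credentials = {'server_url': 'http://localhost:3000'}
provider_model_kwargs = {'temperature': 0.7, 'max_new_tokens': 128}

client = OpenLLM(
    server_url=credentials.get('server_url'),
    llm_kwargs=provider_model_kwargs,  # forwarded verbatim as llm_config in the HTTP request
)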
 import json
 from typing import Type
-from langchain.llms import OpenLLM
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )

     @classmethod
@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
         }

         llm = OpenLLM(
-            max_tokens=10,
+            llm_kwargs={
+                'max_new_tokens': 10
+            },
             **credential_kwargs
         )
...
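The new alias='max_new_tokens' on the max_tokens rule reflects that OpenLLM's generation config expects max_new_tokens rather than the provider-neutral max_tokens, which is also why the credential check now sends llm_kwargs={'max_new_tokens': 10} instead of max_tokens=10. A rough sketch of the renaming such an alias implies; the helper below is hypothetical and only illustrates the mapping, not Dify's actual implementation:

def map_kwargs_by_alias(rules, model_kwargs):
    # Rename provider-neutral kwargs (e.g. 'max_tokens') to the backend's
    # names (e.g. 'max_new_tokens') using each rule's alias when present.
    mapped = {}
    for name, value in model_kwargs.items():
        rule = getattr(rules, name, None)
        alias = getattr(rule, 'alias', None)
        mapped[alias or name] = value
    return mapped

# e.g. {'max_tokens': 128, 'temperature': 0.7} -> {'max_new_tokens': 128, 'temperature': 0.7}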
from __future__ import annotations

import logging
from typing import (
    Any,
    Dict,
    List,
    Optional,
)

import requests
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from pydantic import Field

logger = logging.getLogger(__name__)


class OpenLLM(LLM):
    """Lightweight client for a remote OpenLLM server.

    Unlike the wrapper in ``langchain.llms``, this class does not depend on the
    ``openllm`` package; it only talks to a running server over HTTP.

    If you have an OpenLLM server running (started with ``openllm start``),
    you can use it remotely:

    .. code-block:: python

        from core.third_party.langchain.llms.openllm import OpenLLM
        llm = OpenLLM(server_url='http://localhost:3000')
        llm("What is the difference between a duck and a goose?")
    """

    server_url: Optional[str] = None
    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments sent to the server as the generation config (llm_config)."""

    @property
    def _llm_type(self) -> str:
        return "openllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> str:
        params = {
            "prompt": prompt,
            "llm_config": self.llm_kwargs
        }

        headers = {"Content-Type": "application/json"}
        response = requests.post(
            f'{self.server_url}/v1/generate',
            headers=headers,
            json=params
        )

        if not response.ok:
            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        completion = json_response["responses"][0]

        # The server echoes the prompt; strip it so only the generated text is returned.
        if completion:
            completion = completion[len(prompt):]

        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        return completion

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        raise NotImplementedError(
            "Async call is not supported for OpenLLM at the moment."
        )
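For reference, the wrapper above is a thin HTTP client over OpenLLM's /v1/generate endpoint; the request it issues is roughly equivalent to the following, where the server URL and generation config are placeholders:

import requests

server_url = 'http://localhost:3000'  # wherever `openllm start` is serving
payload = {
    'prompt': 'What is the difference between a duck and a goose?',
    'llm_config': {'max_new_tokens': 128},  # mirrors llm_kwargs
}

response = requests.post(f'{server_url}/v1/generate',
                         headers={'Content-Type': 'application/json'},
                         json=payload)
response.raise_for_status()

# The server returns the prompt plus completion; the wrapper strips the prompt prefix.
print(response.json()['responses'][0])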
@@ -50,4 +50,3 @@ transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
 xinference==0.2.0
-openllm~=0.2.26
\ No newline at end of file
@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):

 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
-                 return_value="abc")
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
+                 return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):

 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
...
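Because the rewritten wrapper no longer imports the openllm package, the tests only need to stub the HTTP-facing _call method; the old _identifying_params patch is obsolete. A minimal standalone sketch of the same mocking pattern, where the module path comes from the diff above and the rest is illustrative, using the pytest-mock mocker fixture these tests already rely on:

def test_openllm_call_is_stubbed(mocker):
    # Patch only the network-facing method of the custom wrapper.
    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                 return_value="abc")

    from core.third_party.langchain.llms.openllm import OpenLLM
    llm = OpenLLM(server_url='http://localhost:3000')
    assert llm("ping") == "abc"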