Commit 6c832ee3 (unverified)
Authored Aug 20, 2023 by takatost; committed by GitHub on Aug 20, 2023
fix: remove openllm pypi package because of this package too large (#931)
Parent: 25264e78
Showing 5 changed files with 97 additions and 13 deletions (+97 -13)
Changed files:
  api/core/model_providers/models/llm/openllm_model.py          +2  -2
  api/core/model_providers/providers/openllm_provider.py        +6  -5
  api/core/third_party/langchain/llms/openllm.py                +87 -0
  api/requirements.txt                                          +1  -2
  ...tests/unit_tests/model_providers/test_openllm_provider.py  +1  -4
api/core/model_providers/models/llm/openllm_model.py
View file @
6c832ee3
from
typing
import
List
,
Optional
,
Any
from
langchain.callbacks.manager
import
Callbacks
from
langchain.llms
import
OpenLLM
from
langchain.schema
import
LLMResult
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.entity.model_params
import
ModelMode
,
ModelKwargs
from
core.third_party.langchain.llms.openllm
import
OpenLLM
class
OpenLLMModel
(
BaseLLM
):
...
...
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
client
=
OpenLLM
(
server_url
=
self
.
credentials
.
get
(
'server_url'
),
callbacks
=
self
.
callbacks
,
**
self
.
provider_model_kwargs
llm_kwargs
=
self
.
provider_model_kwargs
)
return
client
...
...
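In short, the factory no longer spreads provider_model_kwargs directly into the constructor; it hands them to the in-repo wrapper as llm_kwargs, which the wrapper forwards to the server as llm_config. A minimal sketch of the resulting call, with illustrative values (the real ones come from the stored credentials and provider_model_kwargs):

# Illustrative values only; in the real code server_url comes from
# self.credentials.get('server_url') and llm_kwargs from self.provider_model_kwargs.
client = OpenLLM(
    server_url='http://localhost:3000',
    callbacks=None,
    llm_kwargs={'max_new_tokens': 128},
)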
api/core/model_providers/providers/openllm_provider.py

 import json
 from typing import Type

-from langchain.llms import OpenLLM
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
 ...

@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )

     @classmethod
 ...

@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
         }

         llm = OpenLLM(
-            max_tokens=10,
+            llm_kwargs={
+                'max_new_tokens': 10
+            },
             **credential_kwargs
         )
 ...
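The new alias='max_new_tokens' rule maps the provider-level max_tokens setting onto the parameter name the OpenLLM server expects, and the credentials check now passes that parameter through llm_kwargs instead of as a constructor argument. A minimal sketch of that validation probe, assuming an example server URL (the real one comes from credential_kwargs):

# Hypothetical server URL for illustration; the tiny max_new_tokens keeps the probe cheap.
llm = OpenLLM(
    server_url='http://localhost:3000',
    llm_kwargs={'max_new_tokens': 10},
)
llm("ping")  # any request failure is surfaced by the provider's credential validation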
api/core/third_party/langchain/llms/openllm.py  (new file, mode 100644)

from __future__ import annotations

import logging
from typing import (
    Any,
    Dict,
    List,
    Optional,
)

import requests
from langchain.llms.utils import enforce_stop_tokens
from pydantic import Field

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM

logger = logging.getLogger(__name__)


class OpenLLM(LLM):
    """OpenLLM, supporting both in-process model
    instance and remote OpenLLM servers.

    If you have a OpenLLM server running, you can also use it remotely:
        .. code-block:: python

            from langchain.llms import OpenLLM
            llm = OpenLLM(server_url='http://localhost:3000')
            llm("What is the difference between a duck and a goose?")
    """

    server_url: Optional[str] = None
    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to be passed to openllm.LLM"""

    @property
    def _llm_type(self) -> str:
        return "openllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> str:
        params = {
            "prompt": prompt,
            "llm_config": self.llm_kwargs
        }

        headers = {"Content-Type": "application/json"}

        response = requests.post(
            f'{self.server_url}/v1/generate',
            headers=headers,
            json=params
        )

        if not response.ok:
            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        completion = json_response["responses"][0]
        if completion:
            completion = completion[len(prompt):]

        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        return completion

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        raise NotImplementedError(
            "Async call is not supported for OpenLLM at the moment."
        )
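For reference, the wrapper above reduces to a single POST against the server's /v1/generate endpoint. A standalone sketch of that request, assuming a local server started with `openllm start` on port 3000 (the prompt and llm_config values are illustrative):

import requests

server_url = 'http://localhost:3000'  # assumed local OpenLLM server
payload = {
    'prompt': 'What is the difference between a duck and a goose?',
    'llm_config': {'max_new_tokens': 128},  # whatever was passed as llm_kwargs
}

response = requests.post(f'{server_url}/v1/generate',
                         headers={'Content-Type': 'application/json'},
                         json=payload)
response.raise_for_status()
print(response.json()['responses'][0])  # the wrapper additionally strips the echoed prompt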
api/requirements.txt

 ...
@@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.0
-openllm~=0.2.26
\ No newline at end of file
+xinference==0.2.0
\ No newline at end of file
api/tests/unit_tests/model_providers/test_openllm_provider.py

 ...
@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
 ...

@@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):
 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
 ...