Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
e409895c
Unverified
Commit
e409895c
authored
Sep 22, 2023
by
Garfield Dai
Committed by
GitHub
Sep 22, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Feat/huggingface embedding support (#1211)
Co-authored-by:
StyleZhang
<
jasonapring2015@outlook.com
>
parent
32d9b618
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
416 additions
and
28 deletions
+416
-28
huggingface_embedding.py
...model_providers/models/embedding/huggingface_embedding.py
+22
-0
huggingface_hub_provider.py
...ore/model_providers/providers/huggingface_hub_provider.py
+66
-12
huggingface_hub_embedding.py
...d_party/langchain/embeddings/huggingface_hub_embedding.py
+74
-0
huggingface_hub_llm.py
api/core/third_party/langchain/llms/huggingface_hub_llm.py
+1
-1
requirements.txt
api/requirements.txt
+1
-1
.env.example
api/tests/integration_tests/.env.example
+1
-0
test_huggingface_hub_embedding.py
..._tests/models/embedding/test_huggingface_hub_embedding.py
+136
-0
huggingface_hub.tsx
...er/account-setting/model-page/configs/huggingface_hub.tsx
+89
-9
Form.tsx
...ts/header/account-setting/model-page/model-modal/Form.tsx
+25
-4
index.tsx
...s/header/account-setting/model-page/model-modal/index.tsx
+1
-1
No files found.
api/core/model_providers/models/embedding/huggingface_embedding.py
0 → 100644
View file @
e409895c
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.providers.base
import
BaseModelProvider
from
core.third_party.langchain.embeddings.huggingface_hub_embedding
import
HuggingfaceHubEmbeddings
from
core.model_providers.models.embedding.base
import
BaseEmbedding
class
HuggingfaceEmbedding
(
BaseEmbedding
):
def
__init__
(
self
,
model_provider
:
BaseModelProvider
,
name
:
str
):
credentials
=
model_provider
.
get_model_credentials
(
model_name
=
name
,
model_type
=
self
.
type
)
client
=
HuggingfaceHubEmbeddings
(
model
=
name
,
**
credentials
)
super
()
.
__init__
(
model_provider
,
client
,
name
)
def
handle_exceptions
(
self
,
ex
:
Exception
)
->
Exception
:
return
LLMBadRequestError
(
f
"Huggingface embedding: {str(ex)}"
)
api/core/model_providers/providers/huggingface_hub_provider.py
View file @
e409895c
import
json
from
typing
import
Type
import
requests
from
huggingface_hub
import
HfApi
...
...
@@ -10,8 +11,12 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
from
core.model_providers.models.base
import
BaseProviderModel
from
core.third_party.langchain.llms.huggingface_endpoint_llm
import
HuggingFaceEndpointLLM
from
core.third_party.langchain.embeddings.huggingface_hub_embedding
import
HuggingfaceHubEmbeddings
from
core.model_providers.models.embedding.huggingface_embedding
import
HuggingfaceEmbedding
from
models.provider
import
ProviderType
HUGGINGFACE_ENDPOINT_API
=
'https://api.endpoints.huggingface.cloud/v2/endpoint/'
class
HuggingfaceHubProvider
(
BaseModelProvider
):
@
property
...
...
@@ -33,6 +38,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
"""
if
model_type
==
ModelType
.
TEXT_GENERATION
:
model_class
=
HuggingfaceHubModel
elif
model_type
==
ModelType
.
EMBEDDINGS
:
model_class
=
HuggingfaceEmbedding
else
:
raise
NotImplementedError
...
...
@@ -63,7 +70,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
:param model_type:
:param credentials:
"""
if
model_type
!=
ModelType
.
TEXT_GENERATION
:
if
model_type
not
in
[
ModelType
.
TEXT_GENERATION
,
ModelType
.
EMBEDDINGS
]
:
raise
NotImplementedError
if
'huggingfacehub_api_type'
not
in
credentials
\
...
...
@@ -88,19 +95,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
if
'task_type'
not
in
credentials
:
raise
CredentialsValidateFailedError
(
'Task Type must be provided.'
)
if
credentials
[
'task_type'
]
not
in
(
"text2text-generation"
,
"text-generation"
,
"summarization"
):
if
credentials
[
'task_type'
]
not
in
(
"text2text-generation"
,
"text-generation"
,
'feature-extraction'
):
raise
CredentialsValidateFailedError
(
'Task Type must be one of text2text-generation, '
'text-generation,
summariza
tion.'
)
'text-generation,
feature-extrac
tion.'
)
try
:
llm
=
HuggingFaceEndpointLLM
(
endpoint_url
=
credentials
[
'huggingfacehub_endpoint_url'
],
task
=
credentials
[
'task_type'
],
model_kwargs
=
{
"temperature"
:
0.5
,
"max_new_tokens"
:
200
},
huggingfacehub_api_token
=
credentials
[
'huggingfacehub_api_token'
]
)
llm
(
"ping"
)
if
credentials
[
'task_type'
]
==
'feature-extraction'
:
cls
.
check_embedding_valid
(
credentials
,
model_name
)
else
:
cls
.
check_llm_valid
(
credentials
)
except
Exception
as
e
:
raise
CredentialsValidateFailedError
(
f
"{e.__class__.__name__}:{str(e)}"
)
else
:
...
...
@@ -112,13 +115,64 @@ class HuggingfaceHubProvider(BaseModelProvider):
if
'inference'
in
model_info
.
cardData
and
not
model_info
.
cardData
[
'inference'
]:
raise
ValueError
(
f
'Inference API has been turned off for this model {model_name}.'
)
VALID_TASKS
=
(
"text2text-generation"
,
"text-generation"
,
"
summariza
tion"
)
VALID_TASKS
=
(
"text2text-generation"
,
"text-generation"
,
"
feature-extrac
tion"
)
if
model_info
.
pipeline_tag
not
in
VALID_TASKS
:
raise
ValueError
(
f
"Model {model_name} is not a valid task, "
f
"must be one of {VALID_TASKS}."
)
except
Exception
as
e
:
raise
CredentialsValidateFailedError
(
f
"{e.__class__.__name__}:{str(e)}"
)
@
classmethod
def
check_llm_valid
(
cls
,
credentials
:
dict
):
llm
=
HuggingFaceEndpointLLM
(
endpoint_url
=
credentials
[
'huggingfacehub_endpoint_url'
],
task
=
credentials
[
'task_type'
],
model_kwargs
=
{
"temperature"
:
0.5
,
"max_new_tokens"
:
200
},
huggingfacehub_api_token
=
credentials
[
'huggingfacehub_api_token'
]
)
llm
(
"ping"
)
@
classmethod
def
check_embedding_valid
(
cls
,
credentials
:
dict
,
model_name
:
str
):
cls
.
check_endpoint_url_model_repository_name
(
credentials
,
model_name
)
embedding_model
=
HuggingfaceHubEmbeddings
(
model
=
model_name
,
**
credentials
)
embedding_model
.
embed_query
(
"ping"
)
@
classmethod
def
check_endpoint_url_model_repository_name
(
cls
,
credentials
:
dict
,
model_name
:
str
):
try
:
url
=
f
'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers
=
{
'Authorization'
:
f
'Bearer {credentials["huggingfacehub_api_token"]}'
,
'Content-Type'
:
'application/json'
}
response
=
requests
.
get
(
url
=
url
,
headers
=
headers
)
if
response
.
status_code
!=
200
:
raise
ValueError
(
'User Name or Organization Name is invalid.'
)
model_repository_name
=
''
for
item
in
response
.
json
()
.
get
(
"items"
,
[]):
if
item
.
get
(
"status"
,
{})
.
get
(
"url"
)
==
credentials
[
'huggingfacehub_endpoint_url'
]:
model_repository_name
=
item
.
get
(
"model"
,
{})
.
get
(
"repository"
)
break
if
model_repository_name
!=
model_name
:
raise
ValueError
(
f
'Model Name {model_name} is invalid. Please check it on the inference endpoints console.'
)
except
Exception
as
e
:
raise
ValueError
(
str
(
e
))
@
classmethod
def
encrypt_model_credentials
(
cls
,
tenant_id
:
str
,
model_name
:
str
,
model_type
:
ModelType
,
credentials
:
dict
)
->
dict
:
...
...
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
0 → 100644
View file @
e409895c
from
typing
import
Any
,
Dict
,
List
,
Optional
import
json
import
numpy
as
np
from
pydantic
import
BaseModel
,
Extra
,
root_validator
from
langchain.embeddings.base
import
Embeddings
from
langchain.utils
import
get_from_dict_or_env
from
huggingface_hub
import
InferenceClient
HOSTED_INFERENCE_API
=
'hosted_inference_api'
INFERENCE_ENDPOINTS
=
'inference_endpoints'
class
HuggingfaceHubEmbeddings
(
BaseModel
,
Embeddings
):
client
:
Any
model
:
str
huggingface_namespace
:
Optional
[
str
]
=
None
task_type
:
Optional
[
str
]
=
None
huggingfacehub_api_type
:
Optional
[
str
]
=
None
huggingfacehub_api_token
:
Optional
[
str
]
=
None
huggingfacehub_endpoint_url
:
Optional
[
str
]
=
None
class
Config
:
extra
=
Extra
.
forbid
@
root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
values
[
'huggingfacehub_api_token'
]
=
get_from_dict_or_env
(
values
,
"huggingfacehub_api_token"
,
"HUGGINGFACEHUB_API_TOKEN"
)
values
[
'client'
]
=
InferenceClient
(
token
=
values
[
'huggingfacehub_api_token'
])
return
values
def
embed_documents
(
self
,
texts
:
List
[
str
])
->
List
[
List
[
float
]]:
model
=
''
if
self
.
huggingfacehub_api_type
==
HOSTED_INFERENCE_API
:
model
=
self
.
model
else
:
model
=
self
.
huggingfacehub_endpoint_url
output
=
self
.
client
.
post
(
json
=
{
"inputs"
:
texts
,
"options"
:
{
"wait_for_model"
:
False
,
"use_cache"
:
False
}
},
model
=
model
)
embeddings
=
json
.
loads
(
output
.
decode
())
return
self
.
mean_pooling
(
embeddings
)
def
embed_query
(
self
,
text
:
str
)
->
List
[
float
]:
return
self
.
embed_documents
([
text
])[
0
]
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
# Returned values are a list of floats, or a list of list of floats
# (depending on if you sent a string or a list of string,
# and if the automatic reduction, usually mean_pooling for instance was applied for you or not.
# This should be explained on the model's README.)
def
mean_pooling
(
self
,
embeddings
:
List
)
->
List
[
float
]:
# If automatic reduction by giving model, no need to mean_pooling.
# For example one: List[List[float]]
if
not
isinstance
(
embeddings
[
0
][
0
],
list
):
return
embeddings
# For example two: List[List[List[float]]], need to mean_pooling.
sentence_embeddings
=
[
np
.
mean
(
embedding
[
0
],
axis
=
0
)
.
tolist
()
for
embedding
in
embeddings
]
return
sentence_embeddings
api/core/third_party/langchain/llms/huggingface_hub_llm.py
View file @
e409895c
...
...
@@ -16,7 +16,7 @@ class HuggingFaceHubLLM(HuggingFaceHub):
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation`, `text2text-generation`
and `summarization`
for now.
Only supports `text-generation`, `text2text-generation` for now.
Example:
.. code-block:: python
...
...
api/requirements.txt
View file @
e409895c
api/tests/integration_tests/.env.example
View file @
e409895c
...
...
@@ -14,6 +14,7 @@ REPLICATE_API_TOKEN=
# Hugging Face API Key
HUGGINGFACE_API_KEY=
HUGGINGFACE_ENDPOINT_URL=
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL=
# Minimax Credentials
MINIMAX_API_KEY=
...
...
api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
0 → 100644
View file @
e409895c
import
json
import
os
from
unittest.mock
import
patch
,
MagicMock
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.models.embedding.huggingface_embedding
import
HuggingfaceEmbedding
from
core.model_providers.providers.huggingface_hub_provider
import
HuggingfaceHubProvider
from
models.provider
import
Provider
,
ProviderType
,
ProviderModel
DEFAULT_MODEL_NAME
=
'obrizum/all-MiniLM-L6-v2'
def
get_mock_provider
():
return
Provider
(
id
=
'provider_id'
,
tenant_id
=
'tenant_id'
,
provider_name
=
'huggingface_hub'
,
provider_type
=
ProviderType
.
CUSTOM
.
value
,
encrypted_config
=
''
,
is_valid
=
True
,
)
def
get_mock_embedding_model
(
model_name
,
huggingfacehub_api_type
,
mocker
):
valid_api_key
=
os
.
environ
[
'HUGGINGFACE_API_KEY'
]
endpoint_url
=
os
.
environ
[
'HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'
]
model_provider
=
HuggingfaceHubProvider
(
provider
=
get_mock_provider
())
credentials
=
{
'huggingfacehub_api_type'
:
huggingfacehub_api_type
,
'huggingfacehub_api_token'
:
valid_api_key
,
'task_type'
:
'feature-extraction'
}
if
huggingfacehub_api_type
==
'inference_endpoints'
:
credentials
[
'huggingfacehub_endpoint_url'
]
=
endpoint_url
mock_query
=
MagicMock
()
mock_query
.
filter
.
return_value
.
first
.
return_value
=
ProviderModel
(
provider_name
=
'huggingface_hub'
,
model_name
=
model_name
,
model_type
=
ModelType
.
EMBEDDINGS
.
value
,
encrypted_config
=
json
.
dumps
(
credentials
),
is_valid
=
True
,
)
mocker
.
patch
(
'extensions.ext_database.db.session.query'
,
return_value
=
mock_query
)
return
HuggingfaceEmbedding
(
model_provider
=
model_provider
,
name
=
model_name
)
def
decrypt_side_effect
(
tenant_id
,
encrypted_api_key
):
return
encrypted_api_key
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_hosted_inference_api_embed_documents
(
mock_decrypt
,
mocker
):
embedding_model
=
get_mock_embedding_model
(
DEFAULT_MODEL_NAME
,
'hosted_inference_api'
,
mocker
)
rst
=
embedding_model
.
client
.
embed_documents
([
'test'
,
'test1'
])
assert
isinstance
(
rst
,
list
)
assert
len
(
rst
)
==
2
assert
len
(
rst
[
0
])
==
384
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_endpoint_url_inference_api_embed_documents
(
mock_decrypt
,
mocker
):
embedding_model
=
get_mock_embedding_model
(
''
,
'inference_endpoints'
,
mocker
)
mocker
.
patch
(
'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
,
return_value
=
bytes
(
json
.
dumps
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'utf-8'
))
rst
=
embedding_model
.
client
.
embed_documents
([
'test'
,
'test1'
])
assert
isinstance
(
rst
,
list
)
assert
len
(
rst
)
==
2
assert
len
(
rst
[
0
])
==
3
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_endpoint_url_inference_api_embed_documents_two
(
mock_decrypt
,
mocker
):
embedding_model
=
get_mock_embedding_model
(
''
,
'inference_endpoints'
,
mocker
)
mocker
.
patch
(
'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
,
return_value
=
bytes
(
json
.
dumps
([[[[
1
,
2
,
3
],[
4
,
5
,
6
],[
7
,
8
,
9
]]],[[[
1
,
2
,
3
],[
4
,
5
,
6
],[
7
,
8
,
9
]]]]),
'utf-8'
))
rst
=
embedding_model
.
client
.
embed_documents
([
'test'
,
'test1'
])
assert
isinstance
(
rst
,
list
)
assert
len
(
rst
)
==
2
assert
len
(
rst
[
0
])
==
3
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_hosted_inference_api_embed_query
(
mock_decrypt
,
mocker
):
embedding_model
=
get_mock_embedding_model
(
DEFAULT_MODEL_NAME
,
'hosted_inference_api'
,
mocker
)
rst
=
embedding_model
.
client
.
embed_query
(
'test'
)
assert
isinstance
(
rst
,
list
)
assert
len
(
rst
)
==
384
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_endpoint_url_inference_api_embed_query
(
mock_decrypt
,
mocker
):
embedding_model
=
get_mock_embedding_model
(
''
,
'inference_endpoints'
,
mocker
)
mocker
.
patch
(
'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
,
return_value
=
bytes
(
json
.
dumps
([[
1
,
2
,
3
]]),
'utf-8'
))
rst
=
embedding_model
.
client
.
embed_query
(
'test'
)
assert
isinstance
(
rst
,
list
)
assert
len
(
rst
)
==
3
@
patch
(
'core.helper.encrypter.decrypt_token'
,
side_effect
=
decrypt_side_effect
)
def
test_endpoint_url_inference_api_embed_query_two
(
mock_decrypt
,
mocker
):
embedding_model
=
get_mock_embedding_model
(
''
,
'inference_endpoints'
,
mocker
)
mocker
.
patch
(
'core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
,
return_value
=
bytes
(
json
.
dumps
([[[[
1
,
2
,
3
],[
4
,
5
,
6
],[
7
,
8
,
9
]]]]),
'utf-8'
))
rst
=
embedding_model
.
client
.
embed_query
(
'test'
)
assert
isinstance
(
rst
,
list
)
assert
len
(
rst
)
==
3
\ No newline at end of file
web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
View file @
e409895c
...
...
@@ -48,6 +48,15 @@ const config: ProviderConfig = {
]
}
if
(
v
?.
huggingfacehub_api_type
===
'inference_endpoints'
)
{
if
(
v
.
model_type
===
'embeddings'
)
{
return
[
'huggingfacehub_api_token'
,
'huggingface_namespace'
,
'model_name'
,
'huggingfacehub_endpoint_url'
,
'task_type'
,
]
}
return
[
'huggingfacehub_api_token'
,
'model_name'
,
...
...
@@ -68,6 +77,18 @@ const config: ProviderConfig = {
]
}
if
(
v
?.
huggingfacehub_api_type
===
'inference_endpoints'
)
{
if
(
v
.
model_type
===
'embeddings'
)
{
filteredKeys
=
[
'huggingfacehub_api_type'
,
'huggingfacehub_api_token'
,
'huggingface_namespace'
,
'model_name'
,
'huggingfacehub_endpoint_url'
,
'task_type'
,
'model_type'
,
]
}
else
{
filteredKeys
=
[
'huggingfacehub_api_type'
,
'huggingfacehub_api_token'
,
...
...
@@ -77,12 +98,38 @@ const config: ProviderConfig = {
'model_type'
,
]
}
}
return
filteredKeys
.
reduce
((
prev
:
FormValue
,
next
:
string
)
=>
{
prev
[
next
]
=
v
?.[
next
]
||
''
return
prev
},
{})
},
fields
:
[
{
type
:
'radio'
,
key
:
'model_type'
,
required
:
true
,
label
:
{
'en'
:
'Model Type'
,
'zh-Hans'
:
'模型类型'
,
},
options
:
[
{
key
:
'text-generation'
,
label
:
{
'en'
:
'Text Generation'
,
'zh-Hans'
:
'文本生成'
,
},
},
{
key
:
'embeddings'
,
label
:
{
'en'
:
'Embeddings'
,
'zh-Hans'
:
'Embeddings'
,
},
},
],
},
{
type
:
'radio'
,
key
:
'huggingfacehub_api_type'
,
...
...
@@ -121,6 +168,20 @@ const config: ProviderConfig = {
'zh-Hans'
:
'在此输入您的 Hugging Face Hub API Token'
,
},
},
{
hidden
:
(
value
?:
FormValue
)
=>
!
(
value
?.
huggingfacehub_api_type
===
'inference_endpoints'
&&
value
?.
model_type
===
'embeddings'
),
type
:
'text'
,
key
:
'huggingface_namespace'
,
required
:
true
,
label
:
{
'en'
:
'User Name / Organization Name'
,
'zh-Hans'
:
'用户名 / 组织名称'
,
},
placeholder
:
{
'en'
:
'Enter your User Name / Organization Name here'
,
'zh-Hans'
:
'在此输入您的用户名 / 组织名称'
,
},
},
{
type
:
'text'
,
key
:
'model_name'
,
...
...
@@ -148,7 +209,7 @@ const config: ProviderConfig = {
},
},
{
hidden
:
(
value
?:
FormValue
)
=>
value
?.
huggingfacehub_api_type
===
'hosted_inference_api'
,
hidden
:
(
value
?:
FormValue
)
=>
value
?.
huggingfacehub_api_type
===
'hosted_inference_api'
||
value
?.
model_type
===
'embeddings'
,
type
:
'radio'
,
key
:
'task_type'
,
required
:
true
,
...
...
@@ -173,6 +234,25 @@ const config: ProviderConfig = {
},
],
},
{
hidden
:
(
value
?:
FormValue
)
=>
!
(
value
?.
huggingfacehub_api_type
===
'inference_endpoints'
&&
value
?.
model_type
===
'embeddings'
),
type
:
'radio'
,
key
:
'task_type'
,
required
:
true
,
label
:
{
'en'
:
'Task'
,
'zh-Hans'
:
'Task'
,
},
options
:
[
{
key
:
'feature-extraction'
,
label
:
{
'en'
:
'Feature Extraction'
,
'zh-Hans'
:
'Feature Extraction'
,
},
},
],
},
],
},
}
...
...
web/app/components/header/account-setting/model-page/model-modal/Form.tsx
View file @
e409895c
import
{
useEffect
,
useState
}
from
'react'
import
type
{
Dispatch
,
FC
,
SetStateAction
}
from
'react'
import
{
useContext
}
from
'use-context-selector'
import
type
{
Field
,
FormValue
,
ProviderConfigModal
}
from
'../declarations'
import
{
type
Field
,
type
FormValue
,
type
ProviderConfigModal
,
ProviderEnum
}
from
'../declarations'
import
{
useValidate
}
from
'../../key-validator/hooks'
import
{
ValidatingTip
}
from
'../../key-validator/ValidateStatus'
import
{
validateModelProviderFn
}
from
'../utils'
...
...
@@ -85,10 +85,31 @@ const Form: FC<FormProps> = ({
}
const
handleFormChange
=
(
k
:
string
,
v
:
string
)
=>
{
if
(
mode
===
'edit'
&&
!
cleared
)
if
(
mode
===
'edit'
&&
!
cleared
)
{
handleClear
({
[
k
]:
v
})
else
handleMultiFormChange
({
...
value
,
[
k
]:
v
},
k
)
}
else
{
const
extraValue
:
Record
<
string
,
string
>
=
{}
if
(
(
(
k
===
'model_type'
&&
v
===
'embeddings'
&&
value
.
huggingfacehub_api_type
===
'inference_endpoints'
)
||
(
k
===
'huggingfacehub_api_type'
&&
v
===
'inference_endpoints'
&&
value
.
model_type
===
'embeddings'
)
)
&&
modelModal
?.
key
===
ProviderEnum
.
huggingface_hub
)
extraValue
.
task_type
=
'feature-extraction'
if
(
(
(
k
===
'model_type'
&&
v
===
'text-generation'
&&
value
.
huggingfacehub_api_type
===
'inference_endpoints'
)
||
(
k
===
'huggingfacehub_api_type'
&&
v
===
'inference_endpoints'
&&
value
.
model_type
===
'text-generation'
)
)
&&
modelModal
?.
key
===
ProviderEnum
.
huggingface_hub
)
extraValue
.
task_type
=
'text-generation'
handleMultiFormChange
({
...
value
,
[
k
]:
v
,
...
extraValue
},
k
)
}
}
const
handleFocus
=
()
=>
{
...
...
web/app/components/header/account-setting/model-page/model-modal/index.tsx
View file @
e409895c
...
...
@@ -92,7 +92,7 @@ const ModelModal: FC<ModelModalProps> = ({
return (
<Portal>
<div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'>
<div className='w-[640px] max-h-
screen
bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='w-[640px] max-h-
[calc(100vh-120px)]
bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='px-8 pt-8'>
<div className='flex justify-between items-center mb-2'>
<div className='text-xl font-semibold text-gray-900'>{renderTitlePrefix()}</div>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment