Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
f68b05d5
Unverified
Commit
f68b05d5
authored
May 19, 2023
by
John Wang
Committed by
GitHub
May 19, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Feat: support azure openai for temporary (#101)
parent
3b3c604e
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
351 additions
and
110 deletions
+351
-110
config.py
api/config.py
+5
-0
providers.py
api/controllers/console/workspace/providers.py
+29
-15
openai_embedding.py
api/core/embedding/openai_embedding.py
+48
-24
index_builder.py
api/core/index/index_builder.py
+3
-0
llm_builder.py
api/core/llm/llm_builder.py
+34
-9
azure_provider.py
api/core/llm/provider/azure_provider.py
+4
-6
base.py
api/core/llm/provider/base.py
+22
-10
streamable_azure_chat_open_ai.py
api/core/llm/streamable_azure_chat_open_ai.py
+40
-2
streamable_azure_open_ai.py
api/core/llm/streamable_azure_open_ai.py
+64
-0
streamable_chat_open_ai.py
api/core/llm/streamable_chat_open_ai.py
+41
-1
streamable_open_ai.py
api/core/llm/streamable_open_ai.py
+43
-1
index.tsx
...er/account-setting/provider-page/azure-provider/index.tsx
+8
-22
index.tsx
...der/account-setting/provider-page/provider-item/index.tsx
+3
-3
common.en.ts
web/i18n/lang/common.en.ts
+2
-6
common.zh.ts
web/i18n/lang/common.zh.ts
+3
-7
common.ts
web/models/common.ts
+2
-4
No files found.
api/config.py
View file @
f68b05d5
...
@@ -47,6 +47,7 @@ DEFAULTS = {
...
@@ -47,6 +47,7 @@ DEFAULTS = {
'PDF_PREVIEW'
:
'True'
,
'PDF_PREVIEW'
:
'True'
,
'LOG_LEVEL'
:
'INFO'
,
'LOG_LEVEL'
:
'INFO'
,
'DISABLE_PROVIDER_CONFIG_VALIDATION'
:
'False'
,
'DISABLE_PROVIDER_CONFIG_VALIDATION'
:
'False'
,
'DEFAULT_LLM_PROVIDER'
:
'openai'
}
}
...
@@ -181,6 +182,10 @@ class Config:
...
@@ -181,6 +182,10 @@ class Config:
# You could disable it for compatibility with certain OpenAPI providers
# You could disable it for compatibility with certain OpenAPI providers
self
.
DISABLE_PROVIDER_CONFIG_VALIDATION
=
get_bool_env
(
'DISABLE_PROVIDER_CONFIG_VALIDATION'
)
self
.
DISABLE_PROVIDER_CONFIG_VALIDATION
=
get_bool_env
(
'DISABLE_PROVIDER_CONFIG_VALIDATION'
)
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self
.
DEFAULT_LLM_PROVIDER
=
get_env
(
'DEFAULT_LLM_PROVIDER'
)
class
CloudEditionConfig
(
Config
):
class
CloudEditionConfig
(
Config
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
api/controllers/console/workspace/providers.py
View file @
f68b05d5
...
@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
...
@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
not
args
[
'token'
]:
if
args
[
'token'
]:
raise
ValueError
(
'Token is empty'
)
try
:
ProviderService
.
validate_provider_configs
(
try
:
tenant
=
current_user
.
current_tenant
,
ProviderService
.
validate_provider_configs
(
provider_name
=
ProviderName
(
provider
),
configs
=
args
[
'token'
]
)
token_is_valid
=
True
except
ValidateFailedError
:
token_is_valid
=
False
base64_encrypted_token
=
ProviderService
.
get_encrypted_token
(
tenant
=
current_user
.
current_tenant
,
tenant
=
current_user
.
current_tenant
,
provider_name
=
ProviderName
(
provider
),
provider_name
=
ProviderName
(
provider
),
configs
=
args
[
'token'
]
configs
=
args
[
'token'
]
)
)
token_is_valid
=
True
else
:
except
ValidateFailedError
:
base64_encrypted_token
=
None
token_is_valid
=
False
token_is_valid
=
False
tenant
=
current_user
.
current_tenant
tenant
=
current_user
.
current_tenant
base64_encrypted_token
=
ProviderService
.
get_encrypted_token
(
provider_model
=
db
.
session
.
query
(
Provider
)
.
filter
(
tenant
=
current_user
.
current_tenant
,
Provider
.
tenant_id
==
tenant
.
id
,
provider_name
=
ProviderName
(
provider
),
Provider
.
provider_name
==
provider
,
configs
=
args
[
'token'
]
Provider
.
provider_type
==
ProviderType
.
CUSTOM
.
value
)
)
.
first
()
provider_model
=
Provider
.
query
.
filter_by
(
tenant_id
=
tenant
.
id
,
provider_name
=
provider
,
provider_type
=
ProviderType
.
CUSTOM
.
value
)
.
first
()
# Only allow updating token for CUSTOM provider type
# Only allow updating token for CUSTOM provider type
if
provider_model
:
if
provider_model
:
...
@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
...
@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
is_valid
=
token_is_valid
)
is_valid
=
token_is_valid
)
db
.
session
.
add
(
provider_model
)
db
.
session
.
add
(
provider_model
)
if
provider_model
.
is_valid
:
other_providers
=
db
.
session
.
query
(
Provider
)
.
filter
(
Provider
.
tenant_id
==
tenant
.
id
,
Provider
.
provider_name
!=
provider
,
Provider
.
provider_type
==
ProviderType
.
CUSTOM
.
value
)
.
all
()
for
other_provider
in
other_providers
:
other_provider
.
is_valid
=
False
db
.
session
.
commit
()
db
.
session
.
commit
()
if
provider
in
[
ProviderName
.
ANTHROPIC
.
value
,
ProviderName
.
AZURE_OPENAI
.
value
,
ProviderName
.
COHERE
.
value
,
if
provider
in
[
ProviderName
.
ANTHROPIC
.
value
,
ProviderName
.
AZURE_OPENAI
.
value
,
ProviderName
.
COHERE
.
value
,
...
...
api/core/embedding/openai_embedding.py
View file @
f68b05d5
...
@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
...
@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
def
get_embedding
(
def
get_embedding
(
text
:
str
,
text
:
str
,
engine
:
Optional
[
str
]
=
None
,
engine
:
Optional
[
str
]
=
None
,
openai_api_key
:
Optional
[
str
]
=
None
,
api_key
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
List
[
float
]:
)
->
List
[
float
]:
"""Get embedding.
"""Get embedding.
...
@@ -25,11 +26,12 @@ def get_embedding(
...
@@ -25,11 +26,12 @@ def get_embedding(
"""
"""
text
=
text
.
replace
(
"
\n
"
,
" "
)
text
=
text
.
replace
(
"
\n
"
,
" "
)
return
openai
.
Embedding
.
create
(
input
=
[
text
],
engine
=
engine
,
api_key
=
openai_api_key
)[
"data"
][
0
][
"embedding"
]
return
openai
.
Embedding
.
create
(
input
=
[
text
],
engine
=
engine
,
api_key
=
api_key
,
**
kwargs
)[
"data"
][
0
][
"embedding"
]
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
async
def
aget_embedding
(
text
:
str
,
engine
:
Optional
[
str
]
=
None
,
openai_api_key
:
Optional
[
str
]
=
None
)
->
List
[
float
]:
async
def
aget_embedding
(
text
:
str
,
engine
:
Optional
[
str
]
=
None
,
api_key
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
List
[
float
]:
"""Asynchronously get embedding.
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
NOTE: Copied from OpenAI's embedding utils:
...
@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
...
@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
# replace newlines, which can negatively affect performance.
# replace newlines, which can negatively affect performance.
text
=
text
.
replace
(
"
\n
"
,
" "
)
text
=
text
.
replace
(
"
\n
"
,
" "
)
return
(
await
openai
.
Embedding
.
acreate
(
input
=
[
text
],
engine
=
engine
,
api_key
=
openai_api_key
))[
"data"
][
0
][
return
(
await
openai
.
Embedding
.
acreate
(
input
=
[
text
],
engine
=
engine
,
api_key
=
api_key
,
**
kwargs
))[
"data"
][
0
][
"embedding"
"embedding"
]
]
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
def
get_embeddings
(
def
get_embeddings
(
list_of_text
:
List
[
str
],
list_of_text
:
List
[
str
],
engine
:
Optional
[
str
]
=
None
,
engine
:
Optional
[
str
]
=
None
,
openai_api_key
:
Optional
[
str
]
=
None
api_key
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
List
[
List
[
float
]]:
)
->
List
[
List
[
float
]]:
"""Get embeddings.
"""Get embeddings.
...
@@ -67,14 +70,14 @@ def get_embeddings(
...
@@ -67,14 +70,14 @@ def get_embeddings(
# replace newlines, which can negatively affect performance.
# replace newlines, which can negatively affect performance.
list_of_text
=
[
text
.
replace
(
"
\n
"
,
" "
)
for
text
in
list_of_text
]
list_of_text
=
[
text
.
replace
(
"
\n
"
,
" "
)
for
text
in
list_of_text
]
data
=
openai
.
Embedding
.
create
(
input
=
list_of_text
,
engine
=
engine
,
api_key
=
openai_api_key
)
.
data
data
=
openai
.
Embedding
.
create
(
input
=
list_of_text
,
engine
=
engine
,
api_key
=
api_key
,
**
kwargs
)
.
data
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"index"
])
# maintain the same order as input.
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"index"
])
# maintain the same order as input.
return
[
d
[
"embedding"
]
for
d
in
data
]
return
[
d
[
"embedding"
]
for
d
in
data
]
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
@
retry
(
reraise
=
True
,
wait
=
wait_random_exponential
(
min
=
1
,
max
=
20
),
stop
=
stop_after_attempt
(
6
))
async
def
aget_embeddings
(
async
def
aget_embeddings
(
list_of_text
:
List
[
str
],
engine
:
Optional
[
str
]
=
None
,
openai_api_key
:
Optional
[
str
]
=
None
list_of_text
:
List
[
str
],
engine
:
Optional
[
str
]
=
None
,
api_key
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
List
[
List
[
float
]]:
)
->
List
[
List
[
float
]]:
"""Asynchronously get embeddings.
"""Asynchronously get embeddings.
...
@@ -90,7 +93,7 @@ async def aget_embeddings(
...
@@ -90,7 +93,7 @@ async def aget_embeddings(
# replace newlines, which can negatively affect performance.
# replace newlines, which can negatively affect performance.
list_of_text
=
[
text
.
replace
(
"
\n
"
,
" "
)
for
text
in
list_of_text
]
list_of_text
=
[
text
.
replace
(
"
\n
"
,
" "
)
for
text
in
list_of_text
]
data
=
(
await
openai
.
Embedding
.
acreate
(
input
=
list_of_text
,
engine
=
engine
,
api_key
=
openai_api_key
))
.
data
data
=
(
await
openai
.
Embedding
.
acreate
(
input
=
list_of_text
,
engine
=
engine
,
api_key
=
api_key
,
**
kwargs
))
.
data
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"index"
])
# maintain the same order as input.
data
=
sorted
(
data
,
key
=
lambda
x
:
x
[
"index"
])
# maintain the same order as input.
return
[
d
[
"embedding"
]
for
d
in
data
]
return
[
d
[
"embedding"
]
for
d
in
data
]
...
@@ -98,19 +101,30 @@ async def aget_embeddings(
...
@@ -98,19 +101,30 @@ async def aget_embeddings(
class
OpenAIEmbedding
(
BaseEmbedding
):
class
OpenAIEmbedding
(
BaseEmbedding
):
def
__init__
(
def
__init__
(
self
,
self
,
mode
:
str
=
OpenAIEmbeddingMode
.
TEXT_SEARCH_MODE
,
mode
:
str
=
OpenAIEmbeddingMode
.
TEXT_SEARCH_MODE
,
model
:
str
=
OpenAIEmbeddingModelType
.
TEXT_EMBED_ADA_002
,
model
:
str
=
OpenAIEmbeddingModelType
.
TEXT_EMBED_ADA_002
,
deployment_name
:
Optional
[
str
]
=
None
,
deployment_name
:
Optional
[
str
]
=
None
,
openai_api_key
:
Optional
[
str
]
=
None
,
openai_api_key
:
Optional
[
str
]
=
None
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
None
:
)
->
None
:
"""Init params."""
"""Init params."""
super
()
.
__init__
(
**
kwargs
)
new_kwargs
=
{}
if
'embed_batch_size'
in
kwargs
:
new_kwargs
[
'embed_batch_size'
]
=
kwargs
[
'embed_batch_size'
]
if
'tokenizer'
in
kwargs
:
new_kwargs
[
'tokenizer'
]
=
kwargs
[
'tokenizer'
]
super
()
.
__init__
(
**
new_kwargs
)
self
.
mode
=
OpenAIEmbeddingMode
(
mode
)
self
.
mode
=
OpenAIEmbeddingMode
(
mode
)
self
.
model
=
OpenAIEmbeddingModelType
(
model
)
self
.
model
=
OpenAIEmbeddingModelType
(
model
)
self
.
deployment_name
=
deployment_name
self
.
deployment_name
=
deployment_name
self
.
openai_api_key
=
openai_api_key
self
.
openai_api_key
=
openai_api_key
self
.
openai_api_type
=
kwargs
.
get
(
'openai_api_type'
)
self
.
openai_api_version
=
kwargs
.
get
(
'openai_api_version'
)
self
.
openai_api_base
=
kwargs
.
get
(
'openai_api_base'
)
@
handle_llm_exceptions
@
handle_llm_exceptions
def
_get_query_embedding
(
self
,
query
:
str
)
->
List
[
float
]:
def
_get_query_embedding
(
self
,
query
:
str
)
->
List
[
float
]:
...
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
...
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
if
key
not
in
_QUERY_MODE_MODEL_DICT
:
if
key
not
in
_QUERY_MODE_MODEL_DICT
:
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
engine
=
_QUERY_MODE_MODEL_DICT
[
key
]
engine
=
_QUERY_MODE_MODEL_DICT
[
key
]
return
get_embedding
(
query
,
engine
=
engine
,
openai_api_key
=
self
.
openai_api_key
)
return
get_embedding
(
query
,
engine
=
engine
,
api_key
=
self
.
openai_api_key
,
api_type
=
self
.
openai_api_type
,
api_version
=
self
.
openai_api_version
,
api_base
=
self
.
openai_api_base
)
def
_get_text_embedding
(
self
,
text
:
str
)
->
List
[
float
]:
def
_get_text_embedding
(
self
,
text
:
str
)
->
List
[
float
]:
"""Get text embedding."""
"""Get text embedding."""
...
@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
...
@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
return
get_embedding
(
text
,
engine
=
engine
,
openai_api_key
=
self
.
openai_api_key
)
return
get_embedding
(
text
,
engine
=
engine
,
api_key
=
self
.
openai_api_key
,
api_type
=
self
.
openai_api_type
,
api_version
=
self
.
openai_api_version
,
api_base
=
self
.
openai_api_base
)
async
def
_aget_text_embedding
(
self
,
text
:
str
)
->
List
[
float
]:
async
def
_aget_text_embedding
(
self
,
text
:
str
)
->
List
[
float
]:
"""Asynchronously get text embedding."""
"""Asynchronously get text embedding."""
...
@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
...
@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
return
await
aget_embedding
(
text
,
engine
=
engine
,
openai_api_key
=
self
.
openai_api_key
)
return
await
aget_embedding
(
text
,
engine
=
engine
,
api_key
=
self
.
openai_api_key
,
api_type
=
self
.
openai_api_type
,
api_version
=
self
.
openai_api_version
,
api_base
=
self
.
openai_api_base
)
def
_get_text_embeddings
(
self
,
texts
:
List
[
str
])
->
List
[
List
[
float
]]:
def
_get_text_embeddings
(
self
,
texts
:
List
[
str
])
->
List
[
List
[
float
]]:
"""Get text embeddings.
"""Get text embeddings.
...
@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
...
@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
embeddings
=
get_embeddings
(
texts
,
engine
=
engine
,
openai_api_key
=
self
.
openai_api_key
)
embeddings
=
get_embeddings
(
texts
,
engine
=
engine
,
api_key
=
self
.
openai_api_key
,
api_type
=
self
.
openai_api_type
,
api_version
=
self
.
openai_api_version
,
api_base
=
self
.
openai_api_base
)
return
embeddings
return
embeddings
async
def
_aget_text_embeddings
(
self
,
texts
:
List
[
str
])
->
List
[
List
[
float
]]:
async
def
_aget_text_embeddings
(
self
,
texts
:
List
[
str
])
->
List
[
List
[
float
]]:
...
@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
...
@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
if
key
not
in
_TEXT_MODE_MODEL_DICT
:
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
raise
ValueError
(
f
"Invalid mode, model combination: {key}"
)
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
engine
=
_TEXT_MODE_MODEL_DICT
[
key
]
embeddings
=
await
aget_embeddings
(
texts
,
engine
=
engine
,
openai_api_key
=
self
.
openai_api_key
)
embeddings
=
await
aget_embeddings
(
texts
,
engine
=
engine
,
api_key
=
self
.
openai_api_key
,
api_type
=
self
.
openai_api_type
,
api_version
=
self
.
openai_api_version
,
api_base
=
self
.
openai_api_base
)
return
embeddings
return
embeddings
api/core/index/index_builder.py
View file @
f68b05d5
...
@@ -33,8 +33,11 @@ class IndexBuilder:
...
@@ -33,8 +33,11 @@ class IndexBuilder:
max_chunk_overlap
=
20
max_chunk_overlap
=
20
)
)
provider
=
LLMBuilder
.
get_default_provider
(
tenant_id
)
model_credentials
=
LLMBuilder
.
get_model_credentials
(
model_credentials
=
LLMBuilder
.
get_model_credentials
(
tenant_id
=
tenant_id
,
tenant_id
=
tenant_id
,
model_provider
=
provider
,
model_name
=
'text-embedding-ada-002'
model_name
=
'text-embedding-ada-002'
)
)
...
...
api/core/llm/llm_builder.py
View file @
f68b05d5
...
@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
...
@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
from
langchain.llms.fake
import
FakeListLLM
from
langchain.llms.fake
import
FakeListLLM
from
core.constant
import
llm_constant
from
core.constant
import
llm_constant
from
core.llm.error
import
ProviderTokenNotInitError
from
core.llm.provider.base
import
BaseProvider
from
core.llm.provider.llm_provider_service
import
LLMProviderService
from
core.llm.provider.llm_provider_service
import
LLMProviderService
from
core.llm.streamable_azure_chat_open_ai
import
StreamableAzureChatOpenAI
from
core.llm.streamable_azure_open_ai
import
StreamableAzureOpenAI
from
core.llm.streamable_chat_open_ai
import
StreamableChatOpenAI
from
core.llm.streamable_chat_open_ai
import
StreamableChatOpenAI
from
core.llm.streamable_open_ai
import
StreamableOpenAI
from
core.llm.streamable_open_ai
import
StreamableOpenAI
from
models.provider
import
ProviderType
class
LLMBuilder
:
class
LLMBuilder
:
...
@@ -31,16 +36,23 @@ class LLMBuilder:
...
@@ -31,16 +36,23 @@ class LLMBuilder:
if
model_name
==
'fake'
:
if
model_name
==
'fake'
:
return
FakeListLLM
(
responses
=
[])
return
FakeListLLM
(
responses
=
[])
provider
=
cls
.
get_default_provider
(
tenant_id
)
mode
=
cls
.
get_mode_by_model
(
model_name
)
mode
=
cls
.
get_mode_by_model
(
model_name
)
if
mode
==
'chat'
:
if
mode
==
'chat'
:
# llm_cls = StreamableAzureChatOpenAI
if
provider
==
'openai'
:
llm_cls
=
StreamableChatOpenAI
llm_cls
=
StreamableChatOpenAI
else
:
llm_cls
=
StreamableAzureChatOpenAI
elif
mode
==
'completion'
:
elif
mode
==
'completion'
:
llm_cls
=
StreamableOpenAI
if
provider
==
'openai'
:
llm_cls
=
StreamableOpenAI
else
:
llm_cls
=
StreamableAzureOpenAI
else
:
else
:
raise
ValueError
(
f
"model name {model_name} is not supported."
)
raise
ValueError
(
f
"model name {model_name} is not supported."
)
model_credentials
=
cls
.
get_model_credentials
(
tenant_id
,
model_name
)
model_credentials
=
cls
.
get_model_credentials
(
tenant_id
,
provider
,
model_name
)
return
llm_cls
(
return
llm_cls
(
model_name
=
model_name
,
model_name
=
model_name
,
...
@@ -86,18 +98,31 @@ class LLMBuilder:
...
@@ -86,18 +98,31 @@ class LLMBuilder:
raise
ValueError
(
f
"model name {model_name} is not supported."
)
raise
ValueError
(
f
"model name {model_name} is not supported."
)
@
classmethod
@
classmethod
def
get_model_credentials
(
cls
,
tenant_id
:
str
,
model_name
:
str
)
->
dict
:
def
get_model_credentials
(
cls
,
tenant_id
:
str
,
model_
provider
:
str
,
model_
name
:
str
)
->
dict
:
"""
"""
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
Raises an exception if the model_name is not found or if the provider is not found.
Raises an exception if the model_name is not found or if the provider is not found.
"""
"""
if
not
model_name
:
if
not
model_name
:
raise
Exception
(
'model name not found'
)
raise
Exception
(
'model name not found'
)
#
# if model_name not in llm_constant.models:
# raise Exception('model {} not found'.format(model_name))
if
model_name
not
in
llm_constant
.
models
:
# model_provider = llm_constant.models[model_name]
raise
Exception
(
'model {} not found'
.
format
(
model_name
))
model_provider
=
llm_constant
.
models
[
model_name
]
provider_service
=
LLMProviderService
(
tenant_id
=
tenant_id
,
provider_name
=
model_provider
)
provider_service
=
LLMProviderService
(
tenant_id
=
tenant_id
,
provider_name
=
model_provider
)
return
provider_service
.
get_credentials
(
model_name
)
return
provider_service
.
get_credentials
(
model_name
)
@
classmethod
def
get_default_provider
(
cls
,
tenant_id
:
str
)
->
str
:
provider
=
BaseProvider
.
get_valid_provider
(
tenant_id
)
if
not
provider
:
raise
ProviderTokenNotInitError
()
if
provider
.
provider_type
==
ProviderType
.
SYSTEM
.
value
:
provider_name
=
'openai'
else
:
provider_name
=
provider
.
provider_name
return
provider_name
api/core/llm/provider/azure_provider.py
View file @
f68b05d5
...
@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
...
@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
"""
"""
Returns the API credentials for Azure OpenAI as a dictionary.
Returns the API credentials for Azure OpenAI as a dictionary.
"""
"""
encrypted_config
=
self
.
get_provider_api_key
(
model_id
=
model_id
)
config
=
self
.
get_provider_api_key
(
model_id
=
model_id
)
config
=
json
.
loads
(
encrypted_config
)
config
[
'openai_api_type'
]
=
'azure'
config
[
'openai_api_type'
]
=
'azure'
config
[
'deployment_name'
]
=
model_id
config
[
'deployment_name'
]
=
model_id
.
replace
(
'.'
,
''
)
return
config
return
config
def
get_provider_name
(
self
):
def
get_provider_name
(
self
):
...
@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
...
@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
"""
"""
try
:
try
:
config
=
self
.
get_provider_api_key
()
config
=
self
.
get_provider_api_key
()
config
=
json
.
loads
(
config
)
except
:
except
:
config
=
{
config
=
{
'openai_api_type'
:
'azure'
,
'openai_api_type'
:
'azure'
,
'openai_api_version'
:
'2023-03-15-preview'
,
'openai_api_version'
:
'2023-03-15-preview'
,
'openai_api_base'
:
'https://
foo.microsoft.com/bar
'
,
'openai_api_base'
:
'https://
<your-domain-prefix>.openai.azure.com/
'
,
'openai_api_key'
:
''
'openai_api_key'
:
''
}
}
...
@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
...
@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
config
=
{
config
=
{
'openai_api_type'
:
'azure'
,
'openai_api_type'
:
'azure'
,
'openai_api_version'
:
'2023-03-15-preview'
,
'openai_api_version'
:
'2023-03-15-preview'
,
'openai_api_base'
:
'https://
foo.microsoft.com/bar
'
,
'openai_api_base'
:
'https://
<your-domain-prefix>.openai.azure.com/
'
,
'openai_api_key'
:
''
'openai_api_key'
:
''
}
}
...
...
api/core/llm/provider/base.py
View file @
f68b05d5
...
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
...
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
def
__init__
(
self
,
tenant_id
:
str
):
def
__init__
(
self
,
tenant_id
:
str
):
self
.
tenant_id
=
tenant_id
self
.
tenant_id
=
tenant_id
def
get_provider_api_key
(
self
,
model_id
:
Optional
[
str
]
=
None
,
prefer_custom
:
bool
=
True
)
->
str
:
def
get_provider_api_key
(
self
,
model_id
:
Optional
[
str
]
=
None
,
prefer_custom
:
bool
=
True
)
->
Union
[
str
|
dict
]
:
"""
"""
Returns the decrypted API key for the given tenant_id and provider_name.
Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
...
@@ -43,23 +43,35 @@ class BaseProvider(ABC):
...
@@ -43,23 +43,35 @@ class BaseProvider(ABC):
Returns the Provider instance for the given tenant_id and provider_name.
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
"""
providers
=
db
.
session
.
query
(
Provider
)
.
filter
(
return
BaseProvider
.
get_valid_provider
(
self
.
tenant_id
,
self
.
get_provider_name
()
.
value
,
prefer_custom
)
Provider
.
tenant_id
==
self
.
tenant_id
,
Provider
.
provider_name
==
self
.
get_provider_name
()
.
value
@
classmethod
)
.
order_by
(
Provider
.
provider_type
.
desc
()
if
prefer_custom
else
Provider
.
provider_type
)
.
all
()
def
get_valid_provider
(
cls
,
tenant_id
:
str
,
provider_name
:
str
=
None
,
prefer_custom
:
bool
=
False
)
->
Optional
[
Provider
]:
"""
Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
"""
query
=
db
.
session
.
query
(
Provider
)
.
filter
(
Provider
.
tenant_id
==
tenant_id
)
if
provider_name
:
query
=
query
.
filter
(
Provider
.
provider_name
==
provider_name
)
providers
=
query
.
order_by
(
Provider
.
provider_type
.
desc
()
if
prefer_custom
else
Provider
.
provider_type
)
.
all
()
custom_provider
=
None
custom_provider
=
None
system_provider
=
None
system_provider
=
None
for
provider
in
providers
:
for
provider
in
providers
:
if
provider
.
provider_type
==
ProviderType
.
CUSTOM
.
value
:
if
provider
.
provider_type
==
ProviderType
.
CUSTOM
.
value
and
provider
.
is_valid
and
provider
.
encrypted_config
:
custom_provider
=
provider
custom_provider
=
provider
elif
provider
.
provider_type
==
ProviderType
.
SYSTEM
.
value
:
elif
provider
.
provider_type
==
ProviderType
.
SYSTEM
.
value
and
provider
.
is_valid
:
system_provider
=
provider
system_provider
=
provider
if
custom_provider
and
custom_provider
.
is_valid
and
custom_provider
.
encrypted_config
:
if
custom_provider
:
return
custom_provider
return
custom_provider
elif
system_provider
and
system_provider
.
is_valid
:
elif
system_provider
:
return
system_provider
return
system_provider
else
:
else
:
return
None
return
None
...
@@ -80,7 +92,7 @@ class BaseProvider(ABC):
...
@@ -80,7 +92,7 @@ class BaseProvider(ABC):
try
:
try
:
config
=
self
.
get_provider_api_key
()
config
=
self
.
get_provider_api_key
()
except
:
except
:
config
=
'
THIS-IS-A-MOCK-TOKEN
'
config
=
''
if
obfuscated
:
if
obfuscated
:
return
self
.
obfuscated_token
(
config
)
return
self
.
obfuscated_token
(
config
)
...
...
api/core/llm/streamable_azure_chat_open_ai.py
View file @
f68b05d5
import
requests
from
langchain.schema
import
BaseMessage
,
ChatResult
,
LLMResult
from
langchain.schema
import
BaseMessage
,
ChatResult
,
LLMResult
from
langchain.chat_models
import
AzureChatOpenAI
from
langchain.chat_models
import
AzureChatOpenAI
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Dict
,
Any
from
pydantic
import
root_validator
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
class
StreamableAzureChatOpenAI
(
AzureChatOpenAI
):
class
StreamableAzureChatOpenAI
(
AzureChatOpenAI
):
@
root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
"""Validate that api key and python package exists in environment."""
try
:
import
openai
except
ImportError
:
raise
ValueError
(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try
:
values
[
"client"
]
=
openai
.
ChatCompletion
except
AttributeError
:
raise
ValueError
(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if
values
[
"n"
]
<
1
:
raise
ValueError
(
"n must be at least 1."
)
if
values
[
"n"
]
>
1
and
values
[
"streaming"
]:
raise
ValueError
(
"n must be 1 when streaming."
)
return
values
@
property
def
_default_params
(
self
)
->
Dict
[
str
,
Any
]:
"""Get the default parameters for calling OpenAI API."""
return
{
**
super
()
.
_default_params
,
"engine"
:
self
.
deployment_name
,
"api_type"
:
self
.
openai_api_type
,
"api_base"
:
self
.
openai_api_base
,
"api_version"
:
self
.
openai_api_version
,
"api_key"
:
self
.
openai_api_key
,
"organization"
:
self
.
openai_organization
if
self
.
openai_organization
else
None
,
}
def
get_messages_tokens
(
self
,
messages
:
List
[
BaseMessage
])
->
int
:
def
get_messages_tokens
(
self
,
messages
:
List
[
BaseMessage
])
->
int
:
"""Get the number of tokens in a list of messages.
"""Get the number of tokens in a list of messages.
...
...
api/core/llm/streamable_azure_open_ai.py
0 → 100644
View file @
f68b05d5
import
os
from
langchain.llms
import
AzureOpenAI
from
langchain.schema
import
LLMResult
from
typing
import
Optional
,
List
,
Dict
,
Mapping
,
Any
from
pydantic
import
root_validator
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
class
StreamableAzureOpenAI
(
AzureOpenAI
):
openai_api_type
:
str
=
"azure"
openai_api_version
:
str
=
""
@
root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
"""Validate that api key and python package exists in environment."""
try
:
import
openai
values
[
"client"
]
=
openai
.
Completion
except
ImportError
:
raise
ValueError
(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if
values
[
"streaming"
]
and
values
[
"n"
]
>
1
:
raise
ValueError
(
"Cannot stream results when n > 1."
)
if
values
[
"streaming"
]
and
values
[
"best_of"
]
>
1
:
raise
ValueError
(
"Cannot stream results when best_of > 1."
)
return
values
@
property
def
_invocation_params
(
self
)
->
Dict
[
str
,
Any
]:
return
{
**
super
()
.
_invocation_params
,
**
{
"api_type"
:
self
.
openai_api_type
,
"api_base"
:
self
.
openai_api_base
,
"api_version"
:
self
.
openai_api_version
,
"api_key"
:
self
.
openai_api_key
,
"organization"
:
self
.
openai_organization
if
self
.
openai_organization
else
None
,
}}
@
property
def
_identifying_params
(
self
)
->
Mapping
[
str
,
Any
]:
return
{
**
super
()
.
_identifying_params
,
**
{
"api_type"
:
self
.
openai_api_type
,
"api_base"
:
self
.
openai_api_base
,
"api_version"
:
self
.
openai_api_version
,
"api_key"
:
self
.
openai_api_key
,
"organization"
:
self
.
openai_organization
if
self
.
openai_organization
else
None
,
}}
@
handle_llm_exceptions
def
generate
(
self
,
prompts
:
List
[
str
],
stop
:
Optional
[
List
[
str
]]
=
None
)
->
LLMResult
:
return
super
()
.
generate
(
prompts
,
stop
)
@
handle_llm_exceptions_async
async
def
agenerate
(
self
,
prompts
:
List
[
str
],
stop
:
Optional
[
List
[
str
]]
=
None
)
->
LLMResult
:
return
await
super
()
.
agenerate
(
prompts
,
stop
)
api/core/llm/streamable_chat_open_ai.py
View file @
f68b05d5
import
os
from
langchain.schema
import
BaseMessage
,
ChatResult
,
LLMResult
from
langchain.schema
import
BaseMessage
,
ChatResult
,
LLMResult
from
langchain.chat_models
import
ChatOpenAI
from
langchain.chat_models
import
ChatOpenAI
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Dict
,
Any
from
pydantic
import
root_validator
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
class
StreamableChatOpenAI
(
ChatOpenAI
):
class
StreamableChatOpenAI
(
ChatOpenAI
):
@
root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
"""Validate that api key and python package exists in environment."""
try
:
import
openai
except
ImportError
:
raise
ValueError
(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try
:
values
[
"client"
]
=
openai
.
ChatCompletion
except
AttributeError
:
raise
ValueError
(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if
values
[
"n"
]
<
1
:
raise
ValueError
(
"n must be at least 1."
)
if
values
[
"n"
]
>
1
and
values
[
"streaming"
]:
raise
ValueError
(
"n must be 1 when streaming."
)
return
values
@
property
def
_default_params
(
self
)
->
Dict
[
str
,
Any
]:
"""Get the default parameters for calling OpenAI API."""
return
{
**
super
()
.
_default_params
,
"api_type"
:
'openai'
,
"api_base"
:
os
.
environ
.
get
(
"OPENAI_API_BASE"
,
"https://api.openai.com/v1"
),
"api_version"
:
None
,
"api_key"
:
self
.
openai_api_key
,
"organization"
:
self
.
openai_organization
if
self
.
openai_organization
else
None
,
}
def
get_messages_tokens
(
self
,
messages
:
List
[
BaseMessage
])
->
int
:
def
get_messages_tokens
(
self
,
messages
:
List
[
BaseMessage
])
->
int
:
"""Get the number of tokens in a list of messages.
"""Get the number of tokens in a list of messages.
...
...
api/core/llm/streamable_open_ai.py
View file @
f68b05d5
import
os
from
langchain.schema
import
LLMResult
from
langchain.schema
import
LLMResult
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Dict
,
Any
,
Mapping
from
langchain
import
OpenAI
from
langchain
import
OpenAI
from
pydantic
import
root_validator
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
from
core.llm.error_handle_wraps
import
handle_llm_exceptions
,
handle_llm_exceptions_async
class
StreamableOpenAI
(
OpenAI
):
class
StreamableOpenAI
(
OpenAI
):
@
root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
"""Validate that api key and python package exists in environment."""
try
:
import
openai
values
[
"client"
]
=
openai
.
Completion
except
ImportError
:
raise
ValueError
(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if
values
[
"streaming"
]
and
values
[
"n"
]
>
1
:
raise
ValueError
(
"Cannot stream results when n > 1."
)
if
values
[
"streaming"
]
and
values
[
"best_of"
]
>
1
:
raise
ValueError
(
"Cannot stream results when best_of > 1."
)
return
values
@
property
def
_invocation_params
(
self
)
->
Dict
[
str
,
Any
]:
return
{
**
super
()
.
_invocation_params
,
**
{
"api_type"
:
'openai'
,
"api_base"
:
os
.
environ
.
get
(
"OPENAI_API_BASE"
,
"https://api.openai.com/v1"
),
"api_version"
:
None
,
"api_key"
:
self
.
openai_api_key
,
"organization"
:
self
.
openai_organization
if
self
.
openai_organization
else
None
,
}}
@
property
def
_identifying_params
(
self
)
->
Mapping
[
str
,
Any
]:
return
{
**
super
()
.
_identifying_params
,
**
{
"api_type"
:
'openai'
,
"api_base"
:
os
.
environ
.
get
(
"OPENAI_API_BASE"
,
"https://api.openai.com/v1"
),
"api_version"
:
None
,
"api_key"
:
self
.
openai_api_key
,
"organization"
:
self
.
openai_organization
if
self
.
openai_organization
else
None
,
}}
@
handle_llm_exceptions
@
handle_llm_exceptions
def
generate
(
def
generate
(
self
,
prompts
:
List
[
str
],
stop
:
Optional
[
List
[
str
]]
=
None
self
,
prompts
:
List
[
str
],
stop
:
Optional
[
List
[
str
]]
=
None
...
...
web/app/components/header/account-setting/provider-page/azure-provider/index.tsx
View file @
f68b05d5
...
@@ -20,7 +20,7 @@ const AzureProvider = ({
...
@@ -20,7 +20,7 @@ const AzureProvider = ({
const
[
token
,
setToken
]
=
useState
(
provider
.
token
as
ProviderAzureToken
||
{})
const
[
token
,
setToken
]
=
useState
(
provider
.
token
as
ProviderAzureToken
||
{})
const
handleFocus
=
()
=>
{
const
handleFocus
=
()
=>
{
if
(
token
===
provider
.
token
)
{
if
(
token
===
provider
.
token
)
{
token
.
azure
_api_key
=
''
token
.
openai
_api_key
=
''
setToken
({...
token
})
setToken
({...
token
})
onTokenChange
({...
token
})
onTokenChange
({...
token
})
}
}
...
@@ -35,31 +35,17 @@ const AzureProvider = ({
...
@@ -35,31 +35,17 @@ const AzureProvider = ({
<
div
className=
'px-4 py-3'
>
<
div
className=
'px-4 py-3'
>
<
ProviderInput
<
ProviderInput
className=
'mb-4'
className=
'mb-4'
name=
{
t
(
'common.provider.azure.resourceName'
)
}
name=
{
t
(
'common.provider.azure.apiBase'
)
}
placeholder=
{
t
(
'common.provider.azure.resourceNamePlaceholder'
)
}
placeholder=
{
t
(
'common.provider.azure.apiBasePlaceholder'
)
}
value=
{
token
.
azure_api_base
}
value=
{
token
.
openai_api_base
}
onChange=
{
(
v
)
=>
handleChange
(
'azure_api_base'
,
v
)
}
onChange=
{
(
v
)
=>
handleChange
(
'openai_api_base'
,
v
)
}
/>
<
ProviderInput
className=
'mb-4'
name=
{
t
(
'common.provider.azure.deploymentId'
)
}
placeholder=
{
t
(
'common.provider.azure.deploymentIdPlaceholder'
)
}
value=
{
token
.
azure_api_type
}
onChange=
{
v
=>
handleChange
(
'azure_api_type'
,
v
)
}
/>
<
ProviderInput
className=
'mb-4'
name=
{
t
(
'common.provider.azure.apiVersion'
)
}
placeholder=
{
t
(
'common.provider.azure.apiVersionPlaceholder'
)
}
value=
{
token
.
azure_api_version
}
onChange=
{
v
=>
handleChange
(
'azure_api_version'
,
v
)
}
/>
/>
<
ProviderValidateTokenInput
<
ProviderValidateTokenInput
className=
'mb-4'
className=
'mb-4'
name=
{
t
(
'common.provider.azure.apiKey'
)
}
name=
{
t
(
'common.provider.azure.apiKey'
)
}
placeholder=
{
t
(
'common.provider.azure.apiKeyPlaceholder'
)
}
placeholder=
{
t
(
'common.provider.azure.apiKeyPlaceholder'
)
}
value=
{
token
.
azure
_api_key
}
value=
{
token
.
openai
_api_key
}
onChange=
{
v
=>
handleChange
(
'
azure
_api_key'
,
v
)
}
onChange=
{
v
=>
handleChange
(
'
openai
_api_key'
,
v
)
}
onFocus=
{
handleFocus
}
onFocus=
{
handleFocus
}
onValidatedStatus=
{
onValidatedStatus
}
onValidatedStatus=
{
onValidatedStatus
}
providerName=
{
provider
.
provider_name
}
providerName=
{
provider
.
provider_name
}
...
@@ -72,4 +58,4 @@ const AzureProvider = ({
...
@@ -72,4 +58,4 @@ const AzureProvider = ({
)
)
}
}
export
default
AzureProvider
export
default
AzureProvider
\ No newline at end of file
web/app/components/header/account-setting/provider-page/provider-item/index.tsx
View file @
f68b05d5
...
@@ -33,12 +33,12 @@ const ProviderItem = ({
...
@@ -33,12 +33,12 @@ const ProviderItem = ({
const
{
notify
}
=
useContext
(
ToastContext
)
const
{
notify
}
=
useContext
(
ToastContext
)
const
[
token
,
setToken
]
=
useState
<
ProviderAzureToken
|
string
>
(
const
[
token
,
setToken
]
=
useState
<
ProviderAzureToken
|
string
>
(
provider
.
provider_name
===
'azure_openai'
provider
.
provider_name
===
'azure_openai'
?
{
azure_api_base
:
''
,
azure_api_type
:
''
,
azure_api_version
:
''
,
azure_api_key
:
''
}
?
{
openai_api_base
:
''
,
openai_api_key
:
''
}
:
''
:
''
)
)
const
id
=
`
${
provider
.
provider_name
}
-
${
provider
.
provider_type
}
`
const
id
=
`
${
provider
.
provider_name
}
-
${
provider
.
provider_type
}
`
const
isOpen
=
id
===
activeId
const
isOpen
=
id
===
activeId
const
providerKey
=
provider
.
provider_name
===
'azure_openai'
?
(
provider
.
token
as
ProviderAzureToken
)?.
azure
_api_key
:
provider
.
token
const
providerKey
=
provider
.
provider_name
===
'azure_openai'
?
(
provider
.
token
as
ProviderAzureToken
)?.
openai
_api_key
:
provider
.
token
const
comingSoon
=
false
const
comingSoon
=
false
const
isValid
=
provider
.
is_valid
const
isValid
=
provider
.
is_valid
...
@@ -135,4 +135,4 @@ const ProviderItem = ({
...
@@ -135,4 +135,4 @@ const ProviderItem = ({
)
)
}
}
export
default
ProviderItem
export
default
ProviderItem
\ No newline at end of file
web/i18n/lang/common.en.ts
View file @
f68b05d5
...
@@ -148,12 +148,8 @@ const translation = {
...
@@ -148,12 +148,8 @@ const translation = {
editKey
:
'Edit'
,
editKey
:
'Edit'
,
invalidApiKey
:
'Invalid API key'
,
invalidApiKey
:
'Invalid API key'
,
azure
:
{
azure
:
{
resourceName
:
'Resource Name'
,
apiBase
:
'API Base'
,
resourceNamePlaceholder
:
'The name of your Azure OpenAI Resource.'
,
apiBasePlaceholder
:
'The API Base URL of your Azure OpenAI Resource.'
,
deploymentId
:
'Deployment ID'
,
deploymentIdPlaceholder
:
'The deployment name you chose when you deployed the model.'
,
apiVersion
:
'API Version'
,
apiVersionPlaceholder
:
'The API version to use for this operation.'
,
apiKey
:
'API Key'
,
apiKey
:
'API Key'
,
apiKeyPlaceholder
:
'Enter your API key here'
,
apiKeyPlaceholder
:
'Enter your API key here'
,
helpTip
:
'Learn Azure OpenAI Service'
,
helpTip
:
'Learn Azure OpenAI Service'
,
...
...
web/i18n/lang/common.zh.ts
View file @
f68b05d5
...
@@ -149,14 +149,10 @@ const translation = {
...
@@ -149,14 +149,10 @@ const translation = {
editKey
:
'编辑'
,
editKey
:
'编辑'
,
invalidApiKey
:
'无效的 API 密钥'
,
invalidApiKey
:
'无效的 API 密钥'
,
azure
:
{
azure
:
{
resourceName
:
'Resource Name'
,
apiBase
:
'API Base'
,
resourceNamePlaceholder
:
'The name of your Azure OpenAI Resource.'
,
apiBasePlaceholder
:
'输入您的 Azure OpenAI API Base 地址'
,
deploymentId
:
'Deployment ID'
,
deploymentIdPlaceholder
:
'The deployment name you chose when you deployed the model.'
,
apiVersion
:
'API Version'
,
apiVersionPlaceholder
:
'The API version to use for this operation.'
,
apiKey
:
'API Key'
,
apiKey
:
'API Key'
,
apiKeyPlaceholder
:
'
Enter your API key here
'
,
apiKeyPlaceholder
:
'
输入你的 API 密钥
'
,
helpTip
:
'了解 Azure OpenAI Service'
,
helpTip
:
'了解 Azure OpenAI Service'
,
},
},
openaiHosted
:
{
openaiHosted
:
{
...
...
web/models/common.ts
View file @
f68b05d5
...
@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
...
@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
}
}
export
type
ProviderAzureToken
=
{
export
type
ProviderAzureToken
=
{
azure_api_base
:
string
openai_api_base
:
string
azure_api_key
:
string
openai_api_key
:
string
azure_api_type
:
string
azure_api_version
:
string
}
}
export
type
Provider
=
{
export
type
Provider
=
{
provider_name
:
string
provider_name
:
string
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment