ai-tech / dify · Commits · 8226c765

Commit 8226c765, authored Jul 14, 2023 by John Wang

feat: completed anthropic develop

parent 7f320f91

Showing 44 changed files with 475 additions and 280 deletions (+475, -280)
Changed files:

api/config.py (+1, -0)
api/controllers/console/app/audio.py (+2, -2)
api/controllers/console/app/completion.py (+6, -6)
api/controllers/console/app/generator.py (+4, -4)
api/controllers/console/app/message.py (+6, -6)
api/controllers/console/datasets/datasets_document.py (+4, -4)
api/controllers/console/datasets/hit_testing.py (+2, -2)
api/controllers/console/explore/audio.py (+2, -2)
api/controllers/console/explore/completion.py (+6, -6)
api/controllers/console/explore/message.py (+6, -6)
api/controllers/console/workspace/providers.py (+2, -2)
api/controllers/service_api/app/audio.py (+2, -2)
api/controllers/service_api/app/completion.py (+6, -6)
api/controllers/service_api/dataset/document.py (+2, -2)
api/controllers/web/audio.py (+2, -2)
api/controllers/web/completion.py (+6, -6)
api/controllers/web/message.py (+6, -6)
api/core/__init__.py (+8, -0)
api/core/callback_handler/llm_callback_handler.py (+1, -1)
api/core/completion.py (+36, -20)
api/core/constant/llm_constant.py (+19, -2)
api/core/conversation_message_task.py (+3, -2)
api/core/generator/llm_generator.py (+1, -1)
api/core/index/index.py (+1, -1)
api/core/llm/error.py (+3, -0)
api/core/llm/llm_builder.py (+44, -28)
api/core/llm/provider/anthropic_provider.py (+116, -11)
api/core/llm/provider/azure_provider.py (+0, -1)
api/core/llm/provider/base.py (+12, -11)
api/core/llm/provider/openai_provider.py (+11, -0)
api/core/llm/streamable_azure_chat_open_ai.py (+19, -36)
api/core/llm/streamable_azure_open_ai.py (+5, -11)
api/core/llm/streamable_chat_anthropic.py (+39, -0)
api/core/llm/streamable_chat_open_ai.py (+17, -34)
api/core/llm/streamable_open_ai.py (+5, -11)
api/core/llm/whisper.py (+3, -2)
api/core/llm/wrappers/anthropic_wrapper.py (+27, -0)
api/core/llm/wrappers/openai_wrapper.py (+1, -25)
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py (+5, -5)
api/core/tool/dataset_index_tool.py (+2, -2)
api/requirements.txt (+2, -1)
api/services/app_model_config_service.py (+28, -4)
api/services/audio_service.py (+1, -6)
api/services/hit_testing_service.py (+1, -1)
api/config.py

@@ -191,6 +191,7 @@ class Config:
         # hosted provider credentials
         self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
+        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')

         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
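The new key is read the same way as the existing OpenAI one. A minimal sketch of the effect, assuming dify's get_env is essentially an os.environ lookup with a default:

import os

def get_env(key: str, default: str = '') -> str:
    # assumption: get_env is roughly an os.environ lookup with a default
    return os.environ.get(key, default)

ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')  # hosted Anthropic credential, read at startup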
api/controllers/console/app/audio.py

@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
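The same two-line change repeats in nearly every controller below: ProviderTokenNotInitError now carries a human-readable description (see api/core/llm/error.py further down), and the handlers forward it into the HTTP error instead of dropping it. A self-contained sketch of the pattern, with the exception classes stubbed for illustration:

class ProviderTokenNotInitError(Exception):
    description = "Provider Token Not Init"

    def __init__(self, *args, **kwargs):
        self.description = args[0] if args else self.description

class ProviderNotInitializeError(Exception):
    pass

try:
    raise ProviderTokenNotInitError("No valid anthropic model provider credentials found.")
except ProviderTokenNotInitError as ex:
    # before this commit: raise ProviderNotInitializeError(), so the reason was lost
    raise ProviderNotInitializeError(ex.description)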
api/controllers/console/app/completion.py

@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
api/controllers/console/app/generator.py

@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
                 account.current_tenant_id,
                 args['prompt_template']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
                 args['audiences'],
                 args['hoping_to_solve']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/console/app/message.py

@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
         except MoreLikeThisDisabledError:
             yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:

@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
             raise NotFound("Message not found")
         except ConversationNotExistsError:
             raise NotFound("Conversation not found")
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/console/datasets/datasets_document.py

@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
                 document_data=args,
                 account=current_user
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/console/datasets/hit_testing.py

@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/console/explore/audio.py

@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/console/explore/completion.py

@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
api/controllers/console/explore/message.py

@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
         except MoreLikeThisDisabledError:
             yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:

@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/console/workspace/providers.py

@@ -133,7 +133,7 @@ class ProviderTokenApi(Resource):
         db.session.commit()

-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201

@@ -157,7 +157,7 @@ class ProviderTokenValidateApi(Resource):
         args = parser.parse_args()

         # todo: remove this when the provider is supported
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
api/controllers/service_api/app/audio.py

@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/service_api/app/completion.py

@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
api/controllers/service_api/dataset/document.py

@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
                 dataset_process_rule=dataset.latest_process_rule,
                 created_from='api'
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)

         document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
api/controllers/web/audio.py

@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/controllers/web/completion.py

@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
api/controllers/web/message.py

@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:

@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
         except MoreLikeThisDisabledError:
             yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:

@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
api/core/__init__.py

@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
     api_key: str


+class HostedAnthropicCredential(BaseModel):
+    api_key: str
+
+
 class HostedLLMCredentials(BaseModel):
     openai: Optional[HostedOpenAICredential] = None
+    anthropic: Optional[HostedAnthropicCredential] = None


 hosted_llm_credentials = HostedLLMCredentials()

@@ -26,3 +31,6 @@ def init_app(app: Flask):
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
+
+    if app.config.get("ANTHROPIC_API_KEY"):
+        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
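A minimal, self-contained sketch of the hosted-credentials container added above (pydantic v1 style, which is what dify used at the time; the placeholder key is illustrative):

from typing import Optional
from pydantic import BaseModel

class HostedAnthropicCredential(BaseModel):
    api_key: str

class HostedLLMCredentials(BaseModel):
    anthropic: Optional[HostedAnthropicCredential] = None

hosted_llm_credentials = HostedLLMCredentials()
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key='sk-ant-...')  # placeholder
print(hosted_llm_credentials.anthropic.api_key)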
api/core/callback_handler/llm_callback_handler.py

@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })

         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])

     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
api/core/completion.py

@@ -118,6 +118,7 @@ class Completion:
         prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,

@@ -129,6 +130,7 @@ class Completion:
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode=mode
         )

@@ -138,7 +140,8 @@ class Completion:
         return response

     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+                            pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:

@@ -151,10 +154,11 @@ class Completion:
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following CONTEXT as your learned knowledge:
-[CONTEXT]
+                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>

 When answer to user:
 - If you don't know, just say that you don't know.

@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
-[CONTEXT]
+                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>

 When answer to user:
 - If you don't know, just say that you don't know.

@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
             if pre_prompt:
                 human_message_prompt += pre_prompt

-            query_prompt = "\nHuman: {{query}}\nAI: "
+            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "

             if memory:
                 # append chat histories

@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
                     inputs=human_inputs
                 )

-                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
-                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
-                              - memory.llm.max_tokens - curr_message_tokens
+                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+                model_name = model['name']
+                max_tokens = model.get("completion_params").get('max_tokens')
+                rest_tokens = llm_constant.max_context_token_length[model_name] \
+                              - max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)

@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
                 # if histories_param not in human_inputs:
                 #     human_inputs[histories_param] = '{{' + histories_param + '}}'

-                human_message_prompt += "\n\n" + histories
+                human_message_prompt += "\n\n" if human_message_prompt else ""
+                human_message_prompt += "Here is the chat histories between human and assistant, " \
+                                        "inside <histories></histories> XML tags.\n\n<histories>"
+                human_message_prompt += histories + "</histories>"

             human_message_prompt += query_prompt

@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
             model=app_model_config.model_dict
         )

-        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
-        max_tokens = llm.max_tokens
+        model_name = app_model_config.model_dict.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')

         # get prompt without memory and context
         prompt, _ = cls.get_main_llm_prompt(
             mode=mode,
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,

@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
         return rest_tokens

     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                               prompt: Union[str, List[BaseMessage]], mode: str):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
-        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
-        max_tokens = final_llm.max_tokens
+        model_name = model.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = model.get("completion_params").get('max_tokens')

         if mode == 'completion' and isinstance(final_llm, BaseLLM):
             prompt_tokens = final_llm.get_num_tokens(prompt)
         else:
-            prompt_tokens = final_llm.get_messages_tokens(prompt)
+            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)

         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)

@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        llm = LLMBuilder.to_llm_from_model(
             tenant_id=app.tenant_id,
-            model_name='gpt-3.5-turbo',
+            model=app_model_config.model_dict,
             streaming=streaming
         )

@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
         original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
             inputs=message.inputs,

@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
         cls.recale_llm_max_tokens(
             final_llm=llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode='completion'
         )
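The template edits above move the prompt toward Anthropic's expected format: XML-tagged context and histories, plus the "\n\nHuman:" / "\n\nAssistant:" turn markers. A hand-filled illustration of the resulting prompt shape (the real code renders Jinja templates; the values here are placeholders):

prompt = (
    "Use the following context as your learned knowledge, inside "
    "<context></context> XML tags.\n\n"
    "<context>\n{context}\n</context>\n\n"
    "Here is the chat histories between human and assistant, "
    "inside <histories></histories> XML tags.\n\n"
    "<histories>{histories}</histories>"
    "\n\nHuman: {query}\n\nAssistant: "
)
print(prompt.format(context="...", histories="...", query="What is dify?"))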
api/core/constant/llm_constant.py

 from _decimal import Decimal

 models = {
+    'claude-instant-1': 'anthropic',  # 100,000 tokens
+    'claude-2': 'anthropic',  # 100,000 tokens
     'gpt-4': 'openai',  # 8,192 tokens
     'gpt-4-32k': 'openai',  # 32,768 tokens
     'gpt-3.5-turbo': 'openai',  # 4,096 tokens

@@ -10,10 +12,13 @@ models = {
     'text-curie-001': 'openai',  # 2,049 tokens
     'text-babbage-001': 'openai',  # 2,049 tokens
     'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai'  # 8191 tokens, 1536 dimensions
+    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
+    'whisper-1': 'openai'
 }

 max_context_token_length = {
+    'claude-instant-1': 100000,
+    'claude-2': 100000,
     'gpt-4': 8192,
     'gpt-4-32k': 32768,
     'gpt-3.5-turbo': 4096,

@@ -23,17 +28,21 @@ max_context_token_length = {
     'text-curie-001': 2049,
     'text-babbage-001': 2049,
     'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191
+    'text-embedding-ada-002': 8191,
 }

 models_by_mode = {
     'chat': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
         'gpt-3.5-turbo-16k',  # 16,384 tokens
     ],
     'completion': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens

@@ -52,6 +61,14 @@ models_by_mode = {
 model_currency = 'USD'

 model_prices = {
+    'claude-instant-1': {
+        'prompt': Decimal('0.00163'),
+        'completion': Decimal('0.00551'),
+    },
+    'claude-2': {
+        'prompt': Decimal('0.01102'),
+        'completion': Decimal('0.03268'),
+    },
     'gpt-4': {
         'prompt': Decimal('0.03'),
         'completion': Decimal('0.06'),
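The price entries appear to be USD per 1,000 tokens, matching the existing OpenAI rows (gpt-4 at 0.03/0.06). A hedged sketch of how a per-message cost could be derived from the table; message_cost is a hypothetical helper, not part of the commit:

from decimal import Decimal

model_prices = {
    'claude-instant-1': {'prompt': Decimal('0.00163'), 'completion': Decimal('0.00551')},
    'claude-2': {'prompt': Decimal('0.01102'), 'completion': Decimal('0.03268')},
}

def message_cost(model: str, prompt_tokens: int, completion_tokens: int) -> Decimal:
    # assumption: prices are per 1,000 tokens
    prices = model_prices[model]
    return (prices['prompt'] * prompt_tokens + prices['completion'] * completion_tokens) / 1000

print(message_cost('claude-2', 1200, 300))  # 0.023028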
api/core/conversation_message_task.py

@@ -56,7 +56,7 @@ class ConversationMessageTask:
         )

     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
+        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
         self.model_dict['provider'] = provider_name

         override_model_configs = None

@@ -89,7 +89,7 @@ class ConversationMessageTask:
             system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
             system_instruction = system_message.content
             llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-            system_instruction_tokens = llm.get_messages_tokens([system_message])
+            system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])

         if not self.conversation:
             self.is_new_conversation = True

@@ -185,6 +185,7 @@ class ConversationMessageTask:
         if provider and provider.provider_type == ProviderType.SYSTEM.value:
             db.session.query(Provider).filter(
                 Provider.tenant_id == self.app.tenant_id,
+                Provider.provider_name == provider.provider_name,
                 Provider.quota_limit > Provider.quota_used
             ).update({'quota_used': Provider.quota_used + 1})
api/core/generator/llm_generator.py

@@ -52,7 +52,7 @@ class LLMGenerator:
             if not message.answer:
                 continue

-            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
+            message_qa_text = "\n\nHuman:" + message.query + "\n\nAssistant:" + message.answer

             if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                 context += message_qa_text
api/core/index/index.py

@@ -17,7 +17,7 @@ class IndexBuilder:
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
api/core/llm/error.py

@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
     """
     description = "Provider Token Not Init"

+    def __init__(self, *args, **kwargs):
+        self.description = args[0] if args else self.description
+

 class QuotaExceededError(Exception):
     """
api/core/llm/llm_builder.py

@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
 from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
+from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType
+from models.provider import ProviderType, ProviderName


 class LLMBuilder:

@@ -32,43 +33,43 @@ class LLMBuilder:
     @classmethod
     def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id)
+        provider = cls.get_default_provider(tenant_id, model_name)

         model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

+        llm_cls = None
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableChatOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureChatOpenAI
+            elif provider == ProviderName.ANTHROPIC.value:
+                llm_cls = StreamableChatAnthropic
         elif mode == 'completion':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureOpenAI
-        else:
-            raise ValueError(f"model name {model_name} is not supported.")
+
+        if not llm_cls:
+            raise ValueError(f"model name {model_name} is not supported.")

         model_kwargs = {
+            'model_name': model_name,
+            'temperature': kwargs.get('temperature', 0),
+            'max_tokens': kwargs.get('max_tokens', 256),
             'top_p': kwargs.get('top_p', 1),
             'frequency_penalty': kwargs.get('frequency_penalty', 0),
             'presence_penalty': kwargs.get('presence_penalty', 0),
+            'callbacks': kwargs.get('callbacks', None),
+            'streaming': kwargs.get('streaming', False),
         }

-        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+        model_kwargs.update(model_credentials)
+        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)

-        return llm_cls(
-            model_name=model_name,
-            temperature=kwargs.get('temperature', 0),
-            max_tokens=kwargs.get('max_tokens', 256),
-            **model_extras_kwargs,
-            callbacks=kwargs.get('callbacks', None),
-            streaming=kwargs.get('streaming', False),
-            # request_timeout=None
-            **model_credentials
-        )
+        return llm_cls(**model_kwargs)

     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,

@@ -118,14 +119,29 @@ class LLMBuilder:
         return provider_service.get_credentials(model_name)

     @classmethod
-    def get_default_provider(cls, tenant_id: str) -> str:
-        provider = BaseProvider.get_valid_provider(tenant_id)
-        if not provider:
-            raise ProviderTokenNotInitError()
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            provider_name = 'openai'
-        else:
-            provider_name = provider.provider_name
+    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
+        provider_name = llm_constant.models[model_name]
+
+        if provider_name == 'openai':
+            # get the default provider (openai / azure_openai) for the tenant
+            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
+            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
+
+            provider = None
+            if openai_provider:
+                provider = openai_provider
+            elif azure_openai_provider:
+                provider = azure_openai_provider
+
+            if not provider:
+                raise ProviderTokenNotInitError(
+                    f"No valid {provider_name} model provider credentials found. "
+                    f"Please go to Settings -> Model Provider to complete your provider credentials.")
+
+            if provider.provider_type == ProviderType.SYSTEM.value:
+                provider_name = 'openai'
+            else:
+                provider_name = provider.provider_name

         return provider_name
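A reduced sketch of the new dispatch in to_llm: the else branches become explicit elif checks so a third provider can slot in, and an unmatched provider now falls through to a single error. Classes are stubbed and the plain strings stand in for the ProviderName enum values:

class StreamableChatOpenAI: ...
class StreamableAzureChatOpenAI: ...
class StreamableChatAnthropic: ...

def pick_chat_class(provider: str):
    llm_cls = None
    if provider == 'openai':
        llm_cls = StreamableChatOpenAI
    elif provider == 'azure_openai':
        llm_cls = StreamableAzureChatOpenAI
    elif provider == 'anthropic':
        llm_cls = StreamableChatAnthropic
    if not llm_cls:
        raise ValueError(f"provider {provider} is not supported.")
    return llm_cls

print(pick_chat_class('anthropic').__name__)  # StreamableChatAnthropic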
api/core/llm/provider/anthropic_provider.py

-from typing import Optional
+import json
+import logging
+from typing import Optional, Union

 import anthropic
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import HumanMessage
+
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.provider.base import BaseProvider
+from core.llm.provider.errors import ValidateFailedError
 from models.provider import ProviderName


 class AnthropicProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
+        return [
+            {
+                'id': 'claude-instant-1',
+                'name': 'claude-instant-1',
+            },
+            {
+                'id': 'claude-2',
+                'name': 'claude-2',
+            },
+        ]

     def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
-        The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
-        """
-        return {
-            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
-        }
+        return self.get_provider_api_key(model_id=model_id)

     def get_provider_name(self):
-        return ProviderName.ANTHROPIC
\ No newline at end of file
+        return ProviderName.ANTHROPIC
+
+    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+        """
+        Returns the provider configs.
+        """
+        try:
+            config = self.get_provider_api_key()
+        except:
+            config = {
+                'anthropic_api_key': ''
+            }
+
+        if obfuscated:
+            if not config.get('anthropic_api_key'):
+                config = {
+                    'anthropic_api_key': ''
+                }
+
+            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
+            return config
+
+        return config
+
+    def get_encrypted_token(self, config: Union[dict | str]):
+        """
+        Returns the encrypted token.
+        """
+        return json.dumps({
+            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
+        })
+
+    def get_decrypted_token(self, token: str):
+        """
+        Returns the decrypted token.
+        """
+        config = json.loads(token)
+        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
+        return config
+
+    def get_token_type(self):
+        return dict
+
+    def config_validate(self, config: Union[dict | str]):
+        """
+        Validates the given config.
+        """
+        # check OpenAI / Azure OpenAI credential is valid
+        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
+        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
+
+        provider = None
+        if openai_provider:
+            provider = openai_provider
+        elif azure_openai_provider:
+            provider = azure_openai_provider
+
+        if not provider:
+            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
+
+        try:
+            if not isinstance(config, dict):
+                raise ValueError('Config must be a object.')
+
+            if 'anthropic_api_key' not in config:
+                raise ValueError('anthropic_api_key must be provided.')
+
+            chat_llm = ChatAnthropic(
+                model='claude-instant-1',
+                anthropic_api_key=config['anthropic_api_key'],
+                max_tokens_to_sample=10,
+                temperature=0,
+                default_request_timeout=60
+            )
+
+            messages = [
+                HumanMessage(content="ping")
+            ]
+
+            chat_llm(messages)
+        except (anthropic.APIStatusError, anthropic.APIConnectionError, anthropic.RateLimitError) as ex:
+            raise ValidateFailedError(f"Anthropic: {ex.message}")
+        except Exception as ex:
+            logging.exception('Anthropic config validation failed')
+            raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials.")
+
+        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
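config_validate proves the key works by making the cheapest possible call. The same check in isolation, using the langchain API the diff itself imports (a real run needs a valid key and network access; the key below is a placeholder):

from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage

chat_llm = ChatAnthropic(
    model='claude-instant-1',
    anthropic_api_key='sk-ant-...',   # placeholder
    max_tokens_to_sample=10,
    temperature=0,
    default_request_timeout=60,
)

# a ten-token "ping" round-trip is enough to prove the credentials are valid
chat_llm([HumanMessage(content="ping")])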
api/core/llm/provider/azure_provider.py

@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
         return config

     def get_token_type(self):
-        # TODO: change to dict when implemented
         return dict

     def config_validate(self, config: Union[dict | str]):
api/core/llm/provider/base.py

@@ -2,7 +2,7 @@ import base64
 from abc import ABC, abstractmethod
 from typing import Optional, Union

-from core import hosted_llm_credentials
+from core.constant import llm_constant
 from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
 from extensions.ext_database import db
 from libs import rsa

@@ -22,7 +22,10 @@ class BaseProvider(ABC):
         """
         provider = self.get_provider(prefer_custom)
         if not provider:
-            raise ProviderTokenNotInitError()
+            raise ProviderTokenNotInitError(
+                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials.")

         if provider.provider_type == ProviderType.SYSTEM.value:
             quota_used = provider.quota_used if provider.quota_used is not None else 0

@@ -46,7 +49,8 @@ class BaseProvider(ABC):
         return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)

     @classmethod
-    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
+    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) \
+            -> Optional[Provider]:
         """
         Returns the Provider instance for the given tenant_id and provider_name.
         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.

@@ -76,14 +80,11 @@ class BaseProvider(ABC):
         else:
             return None

-    def get_hosted_credentials(self) -> str:
-        if self.get_provider_name() != ProviderName.OPENAI:
-            raise ProviderTokenNotInitError()
-
-        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
-            raise ProviderTokenNotInitError()
-
-        return hosted_llm_credentials.openai.api_key
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        raise ProviderTokenNotInitError(
+            f"No valid {self.get_provider_name().value} model provider credentials found. "
+            f"Please go to Settings -> Model Provider to complete your provider credentials.")

     def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
         """
api/core/llm/provider/openai_provider.py

@@ -4,6 +4,8 @@ from typing import Optional, Union
 import openai
 from openai.error import AuthenticationError, OpenAIError

+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.moderation import Moderation
 from core.llm.provider.base import BaseProvider
 from core.llm.provider.errors import ValidateFailedError

@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
         except Exception as ex:
             logging.exception('OpenAI config validation failed')
             raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials.")
+
+        return hosted_llm_credentials.openai.api_key
api/core/llm/streamable_azure_chat_open_ai.py

-from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any

 from pydantic import root_validator

-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


 class StreamableAzureChatOpenAI(AzureChatOpenAI):

@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }

-    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
-        """Get the number of tokens in a list of messages.
-
-        Args:
-            messages: The messages to count the tokens of.
-
-        Returns:
-            The number of tokens in the messages.
-        """
-        tokens_per_message = 5
-        tokens_per_request = 3
-
-        message_tokens = tokens_per_request
-        message_strs = ''
-        for message in messages:
-            message_strs += message.content
-            message_tokens += tokens_per_message
-
-        # calc once
-        message_tokens += self.get_num_tokens(message_strs)
-
-        return message_tokens
-
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             messages: List[List[BaseMessage]],

@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
     ) -> LLMResult:
         return super().generate(messages, stop, callbacks, **kwargs)

-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(messages, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        model_kwargs = {
+            'top_p': params.get('top_p', 1),
+            'frequency_penalty': params.get('frequency_penalty', 0),
+            'presence_penalty': params.get('presence_penalty', 0),
+        }
+
+        del params['top_p']
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        params['model_kwargs'] = model_kwargs
+
+        return params
api/core/llm/streamable_azure_open_ai.py

@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any

 from pydantic import root_validator

-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


 class StreamableAzureOpenAI(AzureOpenAI):

@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}

-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             prompts: List[str],

@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
     ) -> LLMResult:
         return super().generate(prompts, stop, callbacks, **kwargs)

-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            prompts: List[str],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(prompts, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        return params
api/core/llm/streamable_chat_anthropic.py (new file, 0 → 100644)

+from typing import List, Optional, Any, Dict
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import BaseMessage, LLMResult
+
+from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
+
+
+class StreamableChatAnthropic(ChatAnthropic):
+    """
+    Wrapper around Anthropic's large language model.
+    """
+
+    @handle_anthropic_exceptions
+    def generate(
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            *,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
+    ) -> LLMResult:
+        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
+
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        params['model'] = params.get('model_name')
+        del params['model_name']
+
+        params['max_tokens_to_sample'] = params.get('max_tokens')
+        del params['max_tokens']
+
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        return params
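What get_kwargs_from_model_params does to the generic kwargs dict, worked by hand: OpenAI-style names go in, Anthropic-style names come out, and the penalty parameters are dropped because Anthropic's API has no equivalent:

params = {
    'model_name': 'claude-2',
    'temperature': 0,
    'max_tokens': 256,
    'top_p': 1,
    'frequency_penalty': 0,
    'presence_penalty': 0,
}

params['model'] = params.pop('model_name')
params['max_tokens_to_sample'] = params.pop('max_tokens')
params.pop('frequency_penalty')
params.pop('presence_penalty')

print(params)
# {'temperature': 0, 'top_p': 1, 'model': 'claude-2', 'max_tokens_to_sample': 256}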
api/core/llm/streamable_chat_open_ai.py

@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any

 from pydantic import root_validator

-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


 class StreamableChatOpenAI(ChatOpenAI):

@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }

-    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
-        """Get the number of tokens in a list of messages.
-
-        Args:
-            messages: The messages to count the tokens of.
-
-        Returns:
-            The number of tokens in the messages.
-        """
-        tokens_per_message = 5
-        tokens_per_request = 3
-
-        message_tokens = tokens_per_request
-        message_strs = ''
-        for message in messages:
-            message_strs += message.content
-            message_tokens += tokens_per_message
-
-        # calc once
-        message_tokens += self.get_num_tokens(message_strs)
-
-        return message_tokens
-
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             messages: List[List[BaseMessage]],

@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
     ) -> LLMResult:
         return super().generate(messages, stop, callbacks, **kwargs)

-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(messages, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        model_kwargs = {
+            'top_p': params.get('top_p', 1),
+            'frequency_penalty': params.get('frequency_penalty', 0),
+            'presence_penalty': params.get('presence_penalty', 0),
+        }
+
+        del params['top_p']
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        params['model_kwargs'] = model_kwargs
+
+        return params
api/core/llm/streamable_open_ai.py

@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
 from pydantic import root_validator

-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


 class StreamableOpenAI(OpenAI):

@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}

-    @handle_llm_exceptions
+    @handle_openai_exceptions
    def generate(
             self,
             prompts: List[str],

@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
     ) -> LLMResult:
         return super().generate(prompts, stop, callbacks, **kwargs)

-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            prompts: List[str],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(prompts, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        return params
api/core/llm/whisper.py

 import openai

+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from models.provider import ProviderName
-from core.llm.error_handle_wraps import handle_llm_exceptions
 from core.llm.provider.base import BaseProvider

@@ -13,7 +14,7 @@ class Whisper:
         self.client = openai.Audio
         self.credentials = provider.get_credentials()

-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def transcribe(self, file):
         return self.client.transcribe(
             model='whisper-1',
api/core/llm/wrappers/anthropic_wrapper.py (new file, 0 → 100644)

+import logging
+from functools import wraps
+
+import anthropic
+
+from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMBadRequestError
+
+
+def handle_anthropic_exceptions(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except anthropic.APIConnectionError as e:
+            logging.exception("Failed to connect to Anthropic API.")
+            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
+        except anthropic.RateLimitError:
+            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
+        except anthropic.AuthenticationError as e:
+            raise LLMAuthorizationError(f"Anthropic: {e.message}")
+        except anthropic.BadRequestError as e:
+            raise LLMBadRequestError(f"Anthropic: {e.message}")
+        except anthropic.APIStatusError as e:
+            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
+
+    return wrapper
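Usage sketch, assuming the module above is importable: any function that talks to the Anthropic SDK can be wrapped so callers only ever catch dify's own LLM error hierarchy (ping_anthropic is a hypothetical example, not code from the commit):

from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions

@handle_anthropic_exceptions
def ping_anthropic(chat_llm, message):
    # an anthropic.RateLimitError raised in here surfaces as LLMRateLimitError
    return chat_llm([message])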
api/core/llm/error_handle_wraps.py → api/core/llm/wrappers/openai_wrapper.py (renamed)

@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
     LLMBadRequestError


-def handle_llm_exceptions(func):
+def handle_openai_exceptions(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
         try:

@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
             raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))

     return wrapper
-
-
-def handle_llm_exceptions_async(func):
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        try:
-            return await func(*args, **kwargs)
-        except openai.error.InvalidRequestError as e:
-            logging.exception("Invalid request to OpenAI API.")
-            raise LLMBadRequestError(str(e))
-        except openai.error.APIConnectionError as e:
-            logging.exception("Failed to connect to OpenAI API.")
-            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
-        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
-            logging.exception("OpenAI service unavailable.")
-            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
-        except openai.error.RateLimitError as e:
-            raise LLMRateLimitError(str(e))
-        except openai.error.AuthenticationError as e:
-            raise LLMAuthorizationError(str(e))
-        except openai.error.OpenAIError as e:
-            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
-
-    return wrapper
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

 from typing import Any, List, Dict, Union

 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
+from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel

 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI

@@ -12,8 +12,8 @@ from models.model import Conversation, Message
 class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     conversation: Conversation
     human_prefix: str = "Human"
-    ai_prefix: str = "AI"
-    llm: Union[StreamableChatOpenAI | StreamableOpenAI]
+    ai_prefix: str = "Assistant"
+    llm: BaseLanguageModel
     memory_key: str = "chat_history"
     max_token_limit: int = 2000
     message_limit: int = 10

@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
             return chat_messages

         # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit and chat_messages:
                 pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)

         return chat_messages
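The pruning loop above in miniature: drop the oldest messages until the buffer fits the token budget. Token counts are faked for illustration; the real code asks the LLM's tokenizer via get_num_tokens_from_messages:

chat_messages = ['m1', 'm2', 'm3', 'm4']
count_tokens = lambda msgs: 600 * len(msgs)  # stand-in tokenizer
max_token_limit = 2000

pruned_memory = []
while count_tokens(chat_messages) > max_token_limit and chat_messages:
    pruned_memory.append(chat_messages.pop(0))

print(chat_messages)  # ['m2', 'm3', 'm4'], 1800 tokens, within the limit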
api/core/tool/dataset_index_tool.py

@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
         else:
             model_credentials = LLMBuilder.get_model_credentials(
                 tenant_id=self.dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                 model_name='text-embedding-ada-002'
             )

@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
     async def _arun(self, tool_input: str) -> str:
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=self.dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
api/requirements.txt

@@ -10,7 +10,7 @@ flask-session2==1.3.1
 flask-cors==3.0.10
 gunicorn~=20.1.0
 gevent~=22.10.2
-langchain==0.0.209
+langchain==0.0.230
 openai~=0.27.5
 psycopg2-binary~=2.9.6
 pycryptodome==3.17

@@ -35,3 +35,4 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.5.1
 pyjwt~=2.6.0
+anthropic~=0.3.4
api/services/app_model_config_service.py

@@ -6,6 +6,30 @@ from models.account import Account
 from services.dataset_service import DatasetService
 from core.llm.llm_builder import LLMBuilder

+MODEL_PROVIDERS = [
+    'openai',
+    'anthropic',
+]
+
+MODELS_BY_APP_MODE = {
+    'chat': [
+        'claude-instant-1',
+        'claude-2',
+        'gpt-4',
+        'gpt-4-32k',
+        'gpt-3.5-turbo',
+        'gpt-3.5-turbo-16k',
+    ],
+    'completion': [
+        'claude-instant-1',
+        'claude-2',
+        'gpt-4',
+        'gpt-4-32k',
+        'gpt-3.5-turbo',
+        'gpt-3.5-turbo-16k',
+        'text-davinci-003',
+    ]
+}
+

 class AppModelConfigService:
     @staticmethod

@@ -125,7 +149,7 @@ class AppModelConfigService:
         if not isinstance(config["speech_to_text"]["enabled"], bool):
             raise ValueError("enabled in speech_to_text must be of boolean type")

-        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id)
+        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')

         if config["speech_to_text"]["enabled"] and provider_name != 'openai':
             raise ValueError("provider not support speech to text")

@@ -153,14 +177,14 @@ class AppModelConfigService:
             raise ValueError("model must be of object type")

         # model.provider
-        if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
-            raise ValueError("model.provider must be 'openai'")
+        if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
+            raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")

         # model.name
         if 'name' not in config["model"]:
             raise ValueError("model.name is required")

-        if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
+        if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
             raise ValueError("model.name must be in the specified model list")

         # model.completion_params
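The effect of the relaxed validation, worked through with the constants added above (before this commit any provider other than 'openai' was rejected):

MODEL_PROVIDERS = ['openai', 'anthropic']

config = {'model': {'provider': 'anthropic', 'name': 'claude-2'}}

if 'provider' not in config['model'] or config['model']['provider'] not in MODEL_PROVIDERS:
    raise ValueError(f"model.provider is required and must be in {MODEL_PROVIDERS}")
# passes: 'anthropic' is now an accepted provider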
api/services/audio_service.py

@@ -27,7 +27,7 @@ class AudioService:
             message = f"Audio size larger than {FILE_SIZE} mb"
             raise AudioTooLargeServiceError(message)

-        provider_name = LLMBuilder.get_default_provider(tenant_id)
+        provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
         if provider_name != ProviderName.OPENAI.value:
             raise ProviderNotSupportSpeechToTextServiceError()

@@ -37,8 +37,3 @@ class AudioService:
         buffer.name = 'temp.mp3'

         return Whisper(provider_service.provider).transcribe(buffer)
\ No newline at end of file
api/services/hit_testing_service.py

@@ -31,7 +31,7 @@ class HitTestingService:
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )